From 9a6ee6bf6714f9d78c21ebe554d18fb534f3f40a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Apr 2025 19:50:59 +0000 Subject: [PATCH 001/171] moe refactoring Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 4 + .../layers/fused_moe/modular_kernel.py | 99 +++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/modular_kernel.py diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a209715ede77..e00878276942 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1420,6 +1420,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + if True: + intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) + intermediate_cache3.mul_(curr_topk_weights.view(tokens_in_chunk, -1, 1)) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py new file mode 100644 index 000000000000..a688ae41a751 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -0,0 +1,99 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple +import torch + + +class FusedMoEDispatchQuantize(ABC): + def __init__(self): + pass + + @abstractmethod + def apply( + self, + hidden_states, + hidden_states_scales, + topk_ids, + num_experts, + expert_map, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # returns (hidden_states, scales, sorted_token_ids, expert_ids, inv_perm) # make more abstract? + raise NotImplementedError + + +# store weights, etc. here +class FusedMoEExperts(ABC): + def __init__(self): + pass + + @abstractmethod + def apply(self): + raise NotImplementedError + + +class FusedMoEUnpermuteCombine(ABC): + def __init__(self): + pass + + @abstractmethod + def apply( + self, + out, + hidden_states, + topk_weights, + topk, + inv_perm, + ) -> torch>Tensor: + raise NotImplementedError + + +class ModularFusedMoEKernel(torch.nn.Module): # should this be a module? + def __init__( + self, + dispatch: FusedMoEDispatchQuantize, + fused_experts: FusedMoEExperts, + combine: FusedMoEUnpermuteCombine, + ): + self.dispatch = dispatch + self.fused_experts = fused_experts + self.combine = combine + + def forward( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + self.dispatch() + + fused_out = self.fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + ) + + self.combine(hidden_states, fused_out) + return hidden_states From 82188ddec3648061abece2eaf3949d80716131cd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Apr 2025 22:06:39 +0000 Subject: [PATCH 002/171] module deepgemm moe working Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 16 +- .../layers/fused_moe/deep_gemm_moe.py | 139 ++++++++++++++++++ .../layers/fused_moe/modular_kernel.py | 112 ++++++++++---- 3 files changed, 235 insertions(+), 32 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c57e39f42506..f00745d91cc1 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,8 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8) + deep_gemm_moe_fp8, + modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -380,12 +381,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes + # only aligned sizes TODO: use _valid_deep_gemm here instead? if (N % block_m != 0 or K % block_m != 0 or topk > E): pytest.skip( f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - if N <= 512: + if False and N <= 512: pytest.skip("Skipping N <= 512 until performance issues solved.") vllm_config = VllmConfig() @@ -426,6 +427,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + if True: + dgm = modular_deep_gemm_fused_moe_fp8() + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): + return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -437,7 +445,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 353c8cc9d59f..1c5212e943e6 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, @@ -292,3 +293,141 @@ def deep_gemm_moe_fp8( workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) return out_hidden_states + + +class DeepGemmDispatch(mk.FusedMoEDispatchQuantize): + def __init__(self): + super().__init__() + import deep_gemm as dg + block_m = dg.get_m_alignment_for_contiguous_layout() + self.block_shape = [block_m, block_m] + + def apply( + self, + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + q_hidden_states, q_hidden_states_scale = _fp8_quantize( + hidden_states, + hidden_states_scale, + self.block_shape, + ) + + q_hidden_states, q_hidden_states_scale, _, expert_ids, inv_perm = _moe_permute( + q_hidden_states, + q_hidden_states_scale, + topk_ids, + num_experts, + expert_map, + self.block_shape[0], + ) + + return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm + + +class DeepGemmExperts(mk.FusedMoEExperts): + def __init__(self): + super().__init__() + import deep_gemm as dg + block_m = dg.get_m_alignment_for_contiguous_layout() + self.block_shape = [block_m, block_m] + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int]: + block_m = self.block_shape[0] + M_sum = (M * topk) + num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + workspace1 = M_sum * max(N, K) + workspace2 = M_sum * (N // 2) + # return tuples???? + return (workspace1, workspace2) + + def apply( + self, + q_hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + inplace: bool, + activation: str, + expert_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: # or None? assume inplace? + import deep_gemm as dg + + # chunking in here or in ModularFusedMoEKernel? ignore for now + M_sum = q_hidden_states.shape[0] # double check this + E, N, _ = w1.shape + _, K, _ = w2.shape + + #print(f"M_sum = {M_sum}") + + workspace1 = _resize_cache(workspace13, (M_sum, N)) + workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) + workspace3 = _resize_cache(workspace13, (M_sum, K)) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (q_hidden_states, a1_scale), (w1, w1_scale), + workspace1, + expert_ids) + + if activation == "silu": + torch.ops._C.silu_and_mul(workspace2, + workspace1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(workspace2, + workspace1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + qworkspace2, a2q_scale = _fp8_quantize( + workspace2, a2_scale, self.block_shape) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qworkspace2, a2q_scale), (w2, w2_scale), + workspace3, expert_ids) + + return workspace3 + + +class DeepGemmUnpermuteCombine(mk.FusedMoEUnpermuteCombine): + def __init__(self): + super().__init__() + + def apply( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: Optional[torch.Tensor], + ) -> torch.Tensor: + _moe_unpermute_and_reduce( + out, + hidden_states, + inv_perm, + topk_weights + ) + return out + + +def modular_deep_gemm_fused_moe_fp8() -> mk.ModularFusedMoEKernel: + return mk.ModularFusedMoEKernel( + DeepGemmDispatch(), + DeepGemmExperts(), + DeepGemmUnpermuteCombine(), + ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a688ae41a751..5866129eccbc 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,13 +10,13 @@ def __init__(self): @abstractmethod def apply( self, - hidden_states, - hidden_states_scales, - topk_ids, - num_experts, - expert_map, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # returns (hidden_states, scales, sorted_token_ids, expert_ids, inv_perm) # make more abstract? + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + # returns (hidden_states, scales, expert_ids, inv_perm) # make more abstract? raise NotImplementedError @@ -26,7 +26,32 @@ def __init__(self): pass @abstractmethod - def apply(self): + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int]: + raise NotImplementedError + + @abstractmethod + def apply( + self, + out: torch.Tensor, + q_hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + activation: str, + expert_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + q_hidden_states_scale: Optional[torch.Tensor], + hidden_states_scale_2: Optional[torch.Tensor], + workspace1: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: # or None? assume inplace? raise NotImplementedError @@ -37,12 +62,11 @@ def __init__(self): @abstractmethod def apply( self, - out, - hidden_states, - topk_weights, - topk, - inv_perm, - ) -> torch>Tensor: + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: Optional[torch.Tensor], + ) -> torch.Tensor: raise NotImplementedError @@ -53,6 +77,7 @@ def __init__( fused_experts: FusedMoEExperts, combine: FusedMoEUnpermuteCombine, ): + super().__init__() self.dispatch = dispatch self.fused_experts = fused_experts self.combine = combine @@ -75,25 +100,56 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - self.dispatch() + M, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + top_k = topk_ids.shape[1] - fused_out = self.fused_experts( + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + #print(f"TKN = {topk_ids.numel()} {M*top_k}") + + workspace13_shape, workspace2_shape = self.fused_experts.workspace_shapes(M, N, K, top_k, global_num_experts) + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.empty(workspace13_shape, + device=hidden_states.device, + dtype=hidden_states.dtype) + workspace2 = torch.empty(workspace2_shape, + device=hidden_states.device, + dtype=hidden_states.dtype) + + #print(f"\nbefore M = {hidden_states.shape[0]}") + + hidden_states, a1_scale, expert_ids, inv_perm = self.dispatch.apply( + hidden_states, + a1_scale, + topk_ids, + global_num_experts, + expert_map, + ) + + #print(f"after M = {hidden_states.shape[0]}") + + fused_out = self.fused_experts.apply( hidden_states, w1, w2, - topk_weights, - topk_ids, inplace, activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1_scale, - a2_scale, + expert_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, ) - self.combine(hidden_states, fused_out) - return hidden_states + return self.combine.apply(out_hidden_states, fused_out, topk_weights, inv_perm) From fea4fbfb5a762e88389bf82d81e1044e8b106376 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 01:13:38 +0000 Subject: [PATCH 003/171] working deep gemm, wip cutlass Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 5 + .../layers/fused_moe/cutlass_moe.py | 180 ++++++++++++++++++ .../layers/fused_moe/deep_gemm_moe.py | 2 +- .../layers/fused_moe/modular_kernel.py | 2 + 4 files changed, 188 insertions(+), 1 deletion(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index f00745d91cc1..aa1e5b52073b 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -403,6 +404,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) +# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): +# pytest.skip( +# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 960c7f834857..416630be5220 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -5,6 +5,9 @@ import torch from vllm import _custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import (_resize_cache, + _fp8_perm) #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -173,8 +176,185 @@ def cutlass_moe_fp8( ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, expert_offsets[:-1], problem_sizes2, ab_strides2, ab_strides2, c_strides2) + # Gather tokens c2 = c2[c_map].view(m, topk, k) if not apply_router_weight_on_input: c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) return c2.sum(dim=1) + + +class CutlassDispatch(mk.FusedMoEDispatchQuantize): + def __init__(self): + super().__init__() + + def apply( + self, + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + m = hidden_states.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + device = hidden_states.device + + # a2_scale.numel() != 1 if a2_scale is not None else False + per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + + expert_offsets = torch.empty((num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data(topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, c_map, + num_experts, + n, + k) + + rep_a_q = _fp8_perm(hidden_states, a_map) + rep_a1_scales = hidden_states_scale[a_map] if per_act_token else hidden_states_scale + + return rep_a_q, rep_a1_scales, expert_offsets, c_map + + +class CutlassExperts(mk.FusedMoEExperts): + def __init__( + self, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + ): + super().__init__() + self.ab_strides1 = ab_strides1 + self.c_strides1 = c_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides2 = c_strides2 + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int]: + workspace1 = M * topk * N + workspace2 = M * topk * K + # return tuples???? + return (workspace1, workspace2) + + def apply( + self, + q_hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + inplace: bool, + activation: str, + expert_offsets: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: # or None? assume inplace? + # chunking in here or in ModularFusedMoEKernel? ignore for now + M = q_hidden_states.shape[0] + E, N, _ = w1.shape + _, K, _ = w2.shape + topk = X + device = q_hidden_states.device + + # fix names + c1 = _resize_cache(workspace13, (M * topk, N)) + c2 = _resize_cache(workspace13, (M * topk, K)) + c3 = _resize_cache(workspace2, (M * topk, N // 2)) + + # HACK, share these with other bits + problem_sizes1 = torch.empty((E, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((E, 3), + dtype=torch.int32, + device=E) + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + ops.cutlass_moe_mm(c1, q_hidden_states, w1, a1_scale, w1_scale, + expert_offsets[:-1], + problem_sizes1, + self.ab_strides1, + self.ab_strides1, + self.c_strides1) + + if activation == "silu": + torch.ops._C.silu_and_mul(c3, c1) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(c3, c1) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + intemediate_q, a2_scale = ops.scaled_fp8_quant( + c3, a2_scale, use_per_token_if_dynamic=per_act_token) + + ops.cutlass_moe_mm(c2, intemediate_q, w2, a2_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, self.ab_strides2, + self.ab_strides2, self.c_strides2) + + return c2 + + +class CutlassUnpermuteCombine(mk.FusedMoEUnpermuteCombine): + def __init__(self, out_dtype): + super().__init__() + self.out_dtype = out_dtype + + def apply( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: Optional[torch.Tensor], + ) -> torch.Tensor: + M, topk = topk_weights.shape + K = hidden_states.shape[1] + hidden_states = hidden_states[inv_perm, ...] + hidden_states = hidden_states.view(M, topk, K) + out = hidden_states.mul_(topk_weights.view(M, topk, 1).to(self.out_dtype)).sum(dim=1) + return out + + +def modular_cutlass_moe_fp8( + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype, +) -> mk.ModularFusedMoEKernel: + return mk.ModularFusedMoEKernel( + CutlassDispatch(), + CutlassExperts( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ), + CutlassUnpermuteCombine(out_dtype), + ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 1c5212e943e6..5050e251f54a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -19,7 +19,7 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - +# TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5866129eccbc..fbce6dbb14cf 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -70,6 +70,8 @@ def apply( raise NotImplementedError +# Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) +# TODO: permute/unpermute must be paired class ModularFusedMoEKernel(torch.nn.Module): # should this be a module? def __init__( self, From d78f1c8d53bcbb4d5b0fbae589de07cf1125aa4c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 13:49:41 +0000 Subject: [PATCH 004/171] working cutlass Signed-off-by: Bill Nell --- tests/kernels/test_cutlass_moe.py | 274 ++++++++++++++++++ .../layers/fused_moe/cutlass_moe.py | 116 ++++---- .../layers/fused_moe/deep_gemm_moe.py | 14 +- .../layers/fused_moe/modular_kernel.py | 16 +- 4 files changed, 359 insertions(+), 61 deletions(-) create mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py new file mode 100644 index 000000000000..d4b62a8c86ee --- /dev/null +++ b/tests/kernels/test_cutlass_moe.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8, modular_cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) +from vllm.platforms import current_platform + +NUM_EXPERTS = [40, 64] +TOP_KS = [6, 8] + + +def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) + + +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + + if True: + cutlass_moe_fp8_fn = modular_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) + else: + def cutlass_moe_fp8_fn( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a1_scale=a_scale1 + ): + return cutlass_moe_fp8( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale1 + ) + + cutlass_output = cutlass_moe_fp8_fn( + a, + w1_q, + w2_q, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + a1_scale=a_scale1) + + #print(triton_output) + #print(cutlass_output) + #print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, ab_strides1, + c_strides1, ab_strides2, c_strides2) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + #print(triton_output) + #print(cutlass_output) + #print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=9e-2, + rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 416630be5220..dafe0d8c6014 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -60,7 +60,7 @@ def cutlass_moe_fp8( - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. Shape: scalar or [M] - - out_dtype (torch.Tensor): The output tensor type. + - out_dtype (torch.dtype): The output tensor type. - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, every Rank is responsible for a subset of experts. expert_map is a mapping from global expert-id to local expert-id. When expert_map[i] @@ -190,19 +190,24 @@ def __init__(self): def apply( self, - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], + a: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + k: int # Try to get rid of? ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: - m = hidden_states.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - device = hidden_states.device + m, n = a.shape + device = a.device # a2_scale.numel() != 1 if a2_scale is not None else False - per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + #per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a_q, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) expert_offsets = torch.empty((num_experts + 1), dtype=torch.int32, @@ -221,15 +226,16 @@ def apply( expert_offsets, problem_sizes1, problem_sizes2, - a_map, c_map, + a_map, + c_map, num_experts, - n, - k) + k, + n) - rep_a_q = _fp8_perm(hidden_states, a_map) - rep_a1_scales = hidden_states_scale[a_map] if per_act_token else hidden_states_scale + rep_a_q = _fp8_perm(a_q, a_map) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - return rep_a_q, rep_a1_scales, expert_offsets, c_map + return rep_a_q, rep_a1_scales, expert_offsets, c_map, (problem_sizes1, problem_sizes2) class CutlassExperts(mk.FusedMoEExperts): @@ -249,13 +255,13 @@ def __init__( def workspace_shapes( self, M: int, - N: int, K: int, + N: int, topk: int, num_experts: int ) -> Tuple[int, int]: - workspace1 = M * topk * N - workspace2 = M * topk * K + workspace1 = M * topk * max(2 * N, K) + workspace2 = M * topk * N # return tuples???? return (workspace1, workspace2) @@ -273,52 +279,61 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? # chunking in here or in ModularFusedMoEKernel? ignore for now M = q_hidden_states.shape[0] - E, N, _ = w1.shape - _, K, _ = w2.shape - topk = X - device = q_hidden_states.device + E, N, K = w2.shape # because w1 + w2 are transposed # fix names - c1 = _resize_cache(workspace13, (M * topk, N)) - c2 = _resize_cache(workspace13, (M * topk, K)) - c3 = _resize_cache(workspace2, (M * topk, N // 2)) - - # HACK, share these with other bits - problem_sizes1 = torch.empty((E, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((E, 3), - dtype=torch.int32, - device=E) + c1 = _resize_cache(workspace13, (M, N * 2)) + c2 = _resize_cache(workspace2, (M, N)) + c3 = _resize_cache(workspace13, (M, K)) + # why check a1_scale again? per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - ops.cutlass_moe_mm(c1, q_hidden_states, w1, a1_scale, w1_scale, - expert_offsets[:-1], - problem_sizes1, - self.ab_strides1, - self.ab_strides1, - self.c_strides1) + assert context is not None + problem_sizes1, problem_sizes2 = context + + ops.cutlass_moe_mm( + c1, + q_hidden_states, + w1, + a1_scale, + w1_scale, + expert_offsets[:-1], + problem_sizes1, + self.ab_strides1, + self.ab_strides1, + self.c_strides1 + ) if activation == "silu": - torch.ops._C.silu_and_mul(c3, c1) + torch.ops._C.silu_and_mul(c2, c1) elif activation == "gelu": - torch.ops._C.gelu_and_mul(c3, c1) + torch.ops._C.gelu_and_mul(c2, c1) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") intemediate_q, a2_scale = ops.scaled_fp8_quant( - c3, a2_scale, use_per_token_if_dynamic=per_act_token) + c2, a2_scale, use_per_token_if_dynamic=per_act_token) - ops.cutlass_moe_mm(c2, intemediate_q, w2, a2_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, self.ab_strides2, - self.ab_strides2, self.c_strides2) + ops.cutlass_moe_mm( + c3, + intemediate_q, + w2, + a2_scale, + w2_scale, + expert_offsets[:-1], + problem_sizes2, + self.ab_strides2, + self.ab_strides2, + self.c_strides2 + ) - return c2 + return c3 class CutlassUnpermuteCombine(mk.FusedMoEUnpermuteCombine): @@ -335,10 +350,11 @@ def apply( ) -> torch.Tensor: M, topk = topk_weights.shape K = hidden_states.shape[1] - hidden_states = hidden_states[inv_perm, ...] - hidden_states = hidden_states.view(M, topk, K) - out = hidden_states.mul_(topk_weights.view(M, topk, 1).to(self.out_dtype)).sum(dim=1) - return out + hidden_states = hidden_states[inv_perm, ...].view(-1, topk, K) + hidden_states = (hidden_states * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) + # use moe_sum? to write into out? + return hidden_states + def modular_cutlass_moe_fp8( @@ -346,7 +362,7 @@ def modular_cutlass_moe_fp8( c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - out_dtype, + out_dtype: torch.dtype = torch.half, ) -> mk.ModularFusedMoEKernel: return mk.ModularFusedMoEKernel( CutlassDispatch(), diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 5050e251f54a..2be26da9fa1c 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch @@ -306,10 +306,13 @@ def apply( self, hidden_states: torch.Tensor, hidden_states_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + n: int, # TODO try to get rid of this? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: + # TODO: move? q_hidden_states, q_hidden_states_scale = _fp8_quantize( hidden_states, hidden_states_scale, @@ -325,7 +328,7 @@ def apply( self.block_shape[0], ) - return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm + return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm, None class DeepGemmExperts(mk.FusedMoEExperts): @@ -346,8 +349,8 @@ def workspace_shapes( block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) - workspace1 = M_sum * max(N, K) - workspace2 = M_sum * (N // 2) + workspace1 = M_sum * max(N * 2, K) + workspace2 = M_sum * N # return tuples???? return (workspace1, workspace2) @@ -365,6 +368,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? import deep_gemm as dg diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index fbce6dbb14cf..ed358273bb49 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch @@ -12,11 +12,13 @@ def apply( self, hidden_states: torch.Tensor, hidden_states_scale: Optional[torch.Tensor], + a2: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: - # returns (hidden_states, scales, expert_ids, inv_perm) # make more abstract? + n: int, # TODO try to get rid of this? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: + # returns (hidden_states, scales, expert_ids, inv_perm, context) # make more abstract? raise NotImplementedError @@ -103,8 +105,7 @@ def forward( a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: M, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] + E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] @@ -129,12 +130,14 @@ def forward( #print(f"\nbefore M = {hidden_states.shape[0]}") - hidden_states, a1_scale, expert_ids, inv_perm = self.dispatch.apply( + hidden_states, a1_scale, expert_ids, inv_perm, context = self.dispatch.apply( hidden_states, a1_scale, + a2_scale, topk_ids, global_num_experts, expert_map, + w2.shape[1], ) #print(f"after M = {hidden_states.shape[0]}") @@ -152,6 +155,7 @@ def forward( a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, + context=context, ) return self.combine.apply(out_hidden_states, fused_out, topk_weights, inv_perm) From 8e3e6a9c64ed445cc2cc95d30cf2c05efdec620e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:33:59 +0000 Subject: [PATCH 005/171] deepgemm working again Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 119 +++++++++--------- .../layers/fused_moe/deep_gemm_moe.py | 107 ++++++++-------- .../layers/fused_moe/modular_kernel.py | 101 ++++++++------- 3 files changed, 163 insertions(+), 164 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index dafe0d8c6014..9f9d24aa385c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -184,11 +184,12 @@ def cutlass_moe_fp8( return c2.sum(dim=1) -class CutlassDispatch(mk.FusedMoEDispatchQuantize): - def __init__(self): +class CutlassDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, out_dtype: torch.dtype): super().__init__() + self.out_dtype = out_dtype - def apply( + def dispatch( self, a: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -196,31 +197,27 @@ def apply( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - k: int # Try to get rid of? - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: - m, n = a.shape - device = a.device - - # a2_scale.numel() != 1 if a2_scale is not None else False - #per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + # why do we need to check a2_scale here? per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) a_q, a1_scale = ops.scaled_fp8_quant( a, a1_scale, use_per_token_if_dynamic=per_act_token) - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) + return a_q, a1_scale, topk_ids - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + def combine( + self, + out: torch.Tensor, #TBD + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: + M, topk = topk_weights.shape + K = hidden_states.shape[1] + hidden_states = (hidden_states.view(-1, topk, K) * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) + # use moe_sum? to write into out? + return hidden_states ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, @@ -238,7 +235,7 @@ def apply( return rep_a_q, rep_a1_scales, expert_offsets, c_map, (problem_sizes1, problem_sizes2) -class CutlassExperts(mk.FusedMoEExperts): +class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, ab_strides1: torch.Tensor, @@ -267,36 +264,64 @@ def workspace_shapes( def apply( self, + out: torch.Tensor, # TBD q_hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_ids: torch.Tensor, inplace: bool, activation: str, - expert_offsets: torch.Tensor, + expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? # chunking in here or in ModularFusedMoEKernel? ignore for now M = q_hidden_states.shape[0] - E, N, K = w2.shape # because w1 + w2 are transposed + E, N, _ = w2.shape # because w1 + w2 are transposed + K = w1.shape[1] #? + assert K == w2.shape[-1] + device = q_hidden_states.device + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + expert_offsets = torch.empty((E + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((E, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((E, 3), + dtype=torch.int32, + device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + #print(f"prob {k}, {n}") + + ops.get_cutlass_moe_mm_data(topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + E, + N, + K) + + q_hidden_states = _fp8_perm(q_hidden_states, a_map) + a1_scale = a1_scale[a_map] if per_act_token else a1_scale # fix names c1 = _resize_cache(workspace13, (M, N * 2)) c2 = _resize_cache(workspace2, (M, N)) c3 = _resize_cache(workspace13, (M, K)) - # why check a1_scale again? - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - - assert context is not None - problem_sizes1, problem_sizes2 = context - ops.cutlass_moe_mm( c1, q_hidden_states, @@ -333,28 +358,9 @@ def apply( self.c_strides2 ) - return c3 - - -class CutlassUnpermuteCombine(mk.FusedMoEUnpermuteCombine): - def __init__(self, out_dtype): - super().__init__() - self.out_dtype = out_dtype - - def apply( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: Optional[torch.Tensor], - ) -> torch.Tensor: - M, topk = topk_weights.shape - K = hidden_states.shape[1] - hidden_states = hidden_states[inv_perm, ...].view(-1, topk, K) - hidden_states = (hidden_states * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) - # use moe_sum? to write into out? - return hidden_states + c3 = c3[c_map, ...] + return c3 def modular_cutlass_moe_fp8( @@ -363,14 +369,13 @@ def modular_cutlass_moe_fp8( ab_strides2: torch.Tensor, c_strides2: torch.Tensor, out_dtype: torch.dtype = torch.half, -) -> mk.ModularFusedMoEKernel: - return mk.ModularFusedMoEKernel( - CutlassDispatch(), +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + CutlassDispatchCombine(out_dtype), CutlassExperts( ab_strides1, c_strides1, ab_strides2, c_strides2, ), - CutlassUnpermuteCombine(out_dtype), ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 2be26da9fa1c..3b39e45c7ea4 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch @@ -19,6 +19,13 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + +def deep_gemm_block_shape() -> List[int]: + import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() + return [block, block] + + # TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, @@ -109,7 +116,8 @@ def _moe_unpermute_and_reduce( """ M, topk = topk_weight.shape K = curr_hidden.shape[1] - curr_hidden = curr_hidden[inv_perm, ...] + if inv_perm is not None: + curr_hidden = curr_hidden[inv_perm, ...] curr_hidden = curr_hidden.view(-1, topk, K) curr_hidden.mul_(topk_weight.view(M, -1, 1)) ops.moe_sum(curr_hidden, out) @@ -295,48 +303,46 @@ def deep_gemm_moe_fp8( return out_hidden_states -class DeepGemmDispatch(mk.FusedMoEDispatchQuantize): +class DeepGemmDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): def __init__(self): super().__init__() - import deep_gemm as dg - block_m = dg.get_m_alignment_for_contiguous_layout() - self.block_shape = [block_m, block_m] + self.block_shape = deep_gemm_block_shape() - def apply( + def dispatch( self, - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], + a: torch.Tensor, + a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - n: int, # TODO try to get rid of this? - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: - # TODO: move? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: q_hidden_states, q_hidden_states_scale = _fp8_quantize( - hidden_states, - hidden_states_scale, + a, + a1_scale, self.block_shape, ) + return q_hidden_states, q_hidden_states_scale, topk_ids - q_hidden_states, q_hidden_states_scale, _, expert_ids, inv_perm = _moe_permute( - q_hidden_states, - q_hidden_states_scale, - topk_ids, - num_experts, - expert_map, - self.block_shape[0], + def combine( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: + _moe_unpermute_and_reduce( + out, + hidden_states, + None, + topk_weights ) - - return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm, None + return out -class DeepGemmExperts(mk.FusedMoEExperts): +class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): super().__init__() - import deep_gemm as dg - block_m = dg.get_m_alignment_for_contiguous_layout() - self.block_shape = [block_m, block_m] + self.block_shape = deep_gemm_block_shape() def workspace_shapes( self, @@ -352,33 +358,43 @@ def workspace_shapes( workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N # return tuples???? - return (workspace1, workspace2) + return (workspace1, workspace2) # TODO add type def apply( self, + out: torch.Tensor, #unused tbd q_hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_ids: torch.Tensor, inplace: bool, activation: str, - expert_ids: torch.Tensor, + expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? import deep_gemm as dg # chunking in here or in ModularFusedMoEKernel? ignore for now - M_sum = q_hidden_states.shape[0] # double check this E, N, _ = w1.shape _, K, _ = w2.shape #print(f"M_sum = {M_sum}") + q_hidden_states, a1_scale, _, expert_ids, inv_perm = _moe_permute( + q_hidden_states, + a1_scale, + topk_ids, + E, + expert_map, + self.block_shape[0], + ) + + M_sum = q_hidden_states.shape[0] workspace1 = _resize_cache(workspace13, (M_sum, N)) workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) workspace3 = _resize_cache(workspace13, (M_sum, K)) @@ -406,32 +422,13 @@ def apply( (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) - return workspace3 - - -class DeepGemmUnpermuteCombine(mk.FusedMoEUnpermuteCombine): - def __init__(self): - super().__init__() + workspace3 = workspace3[inv_perm, ...] - def apply( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: Optional[torch.Tensor], - ) -> torch.Tensor: - _moe_unpermute_and_reduce( - out, - hidden_states, - inv_perm, - topk_weights - ) - return out + return workspace3 -def modular_deep_gemm_fused_moe_fp8() -> mk.ModularFusedMoEKernel: - return mk.ModularFusedMoEKernel( - DeepGemmDispatch(), +def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + DeepGemmDispatchCombine(), DeepGemmExperts(), - DeepGemmUnpermuteCombine(), ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ed358273bb49..cef11efe22a3 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -3,27 +3,36 @@ import torch -class FusedMoEDispatchQuantize(ABC): +class FusedMoEQuantizeDispatchCombine(ABC): def __init__(self): pass @abstractmethod - def apply( + def dispatch( self, - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], - a2: Optional[torch.Tensor], + a: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - n: int, # TODO try to get rid of this? - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: - # returns (hidden_states, scales, expert_ids, inv_perm, context) # make more abstract? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + # TODO: figure this out + # returns (quantized+dispatched hidden_states, quantized+dispatched scales, dispatched topk_ids) + raise NotImplementedError + + @abstractmethod + def combine( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: raise NotImplementedError # store weights, etc. here -class FusedMoEExperts(ABC): +class FusedMoEPermuteExpertsUnpermute(ABC): def __init__(self): pass @@ -45,46 +54,31 @@ def apply( q_hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool, activation: str, - expert_ids: torch.Tensor, + expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], - q_hidden_states_scale: Optional[torch.Tensor], - hidden_states_scale_2: Optional[torch.Tensor], - workspace1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, workspace2: torch.Tensor, ) -> torch.Tensor: # or None? assume inplace? raise NotImplementedError -class FusedMoEUnpermuteCombine(ABC): - def __init__(self): - pass - - @abstractmethod - def apply( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: Optional[torch.Tensor], - ) -> torch.Tensor: - raise NotImplementedError - - # Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) # TODO: permute/unpermute must be paired -class ModularFusedMoEKernel(torch.nn.Module): # should this be a module? +class FusedMoEModularKernel(torch.nn.Module): # should this be a module? def __init__( self, - dispatch: FusedMoEDispatchQuantize, - fused_experts: FusedMoEExperts, - combine: FusedMoEUnpermuteCombine, + dispatch_combine: FusedMoEQuantizeDispatchCombine, + fused_experts: FusedMoEPermuteExpertsUnpermute, ): super().__init__() - self.dispatch = dispatch + self.dispatch_combine = dispatch_combine self.fused_experts = fused_experts - self.combine = combine def forward( self, @@ -110,14 +104,17 @@ def forward( global_num_experts = E top_k = topk_ids.shape[1] - if inplace: + if False and inplace: out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) #print(f"TKN = {topk_ids.numel()} {M*top_k}") - workspace13_shape, workspace2_shape = self.fused_experts.workspace_shapes(M, N, K, top_k, global_num_experts) + workspace13_shape, workspace2_shape = ( + self.fused_experts.workspace_shapes( + M, N, K, top_k, global_num_experts) + ) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 @@ -130,32 +127,32 @@ def forward( #print(f"\nbefore M = {hidden_states.shape[0]}") - hidden_states, a1_scale, expert_ids, inv_perm, context = self.dispatch.apply( - hidden_states, - a1_scale, - a2_scale, - topk_ids, - global_num_experts, - expert_map, - w2.shape[1], + hidden_states, a1_scale, new_topk_ids = self.dispatch_combine.dispatch( + a=hidden_states, + a1_scale=a1_scale, + a2_scale=a2_scale, + topk_ids=topk_ids, + num_experts=global_num_experts, + expert_map=expert_map, ) #print(f"after M = {hidden_states.shape[0]}") fused_out = self.fused_experts.apply( - hidden_states, - w1, - w2, - inplace, - activation, - expert_ids, + out=hidden_states, + q_hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_ids=new_topk_ids, + inplace=inplace, + activation=activation, + expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, - context=context, ) - return self.combine.apply(out_hidden_states, fused_out, topk_weights, inv_perm) + return self.dispatch_combine.combine(out_hidden_states, fused_out, topk_weights) From e41e4bf84bb654741c421ad365cc15f8e7f9fbf5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:36:32 +0000 Subject: [PATCH 006/171] cutlass working again Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 9f9d24aa385c..4ebf48d026cc 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -283,7 +283,9 @@ def apply( M = q_hidden_states.shape[0] E, N, _ = w2.shape # because w1 + w2 are transposed K = w1.shape[1] #? + topk = topk_ids.shape[1] assert K == w2.shape[-1] + assert E == w1.shape[0] device = q_hidden_states.device per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( @@ -318,9 +320,9 @@ def apply( a1_scale = a1_scale[a_map] if per_act_token else a1_scale # fix names - c1 = _resize_cache(workspace13, (M, N * 2)) - c2 = _resize_cache(workspace2, (M, N)) - c3 = _resize_cache(workspace13, (M, K)) + c1 = _resize_cache(workspace13, (M * topk, N * 2)) + c2 = _resize_cache(workspace2, (M * topk, N)) + c3 = _resize_cache(workspace13, (M * topk, K)) ops.cutlass_moe_mm( c1, From 0b877baa9313ae9a892253033fe5962904e19bf1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:37:13 +0000 Subject: [PATCH 007/171] cutlass working again Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index cef11efe22a3..c780a494f4e3 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -104,7 +104,9 @@ def forward( global_num_experts = E top_k = topk_ids.shape[1] - if False and inplace: + assert not inplace, "NYI" + + if inplace: out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) From 5d76ee93a195c0b3e5f422b5917751407ee8181d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 22:14:54 +0000 Subject: [PATCH 008/171] fix inplace, format and name cleanups Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 21 +- tests/kernels/test_cutlass_moe.py | 132 +++++++----- .../layers/fused_moe/cutlass_moe.py | 195 ++++++++---------- .../layers/fused_moe/deep_gemm_moe.py | 124 +++++------ .../layers/fused_moe/fused_moe.py | 4 - .../layers/fused_moe/modular_kernel.py | 186 ++++++++--------- 6 files changed, 319 insertions(+), 343 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index aa1e5b52073b..3e372b007aac 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -404,9 +404,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) -# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): -# pytest.skip( -# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): + # pytest.skip( + # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") score = torch.randn((M, E), dtype=dtype) @@ -434,8 +434,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): if True: dgm = modular_deep_gemm_fused_moe_fp8() - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): - return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids): + return dgm(a, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 @@ -450,7 +458,8 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index d4b62a8c86ee..0dc572c72885 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + import pytest import torch from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8, modular_cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8, modular_cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) from vllm.platforms import current_platform @@ -13,6 +16,48 @@ TOP_KS = [6, 8] +def get_cutlass_moe_fp8(ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype=torch.half) -> Callable: + if True: + return modular_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + out_dtype, + ) + else: + + def cutlass_moe_fp8_fn( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + a1_scale: Optional[torch.Tensor], + ) -> torch.Tensor: + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale, + out_dtype=out_dtype) + + return cutlass_moe_fp8_fn + + def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -21,18 +66,22 @@ def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) + + cutlass_moe_fp8_fn = get_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) + + return cutlass_moe_fp8_fn(a, + w1_q, + w2_q, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + a1_scale=a_scale) @pytest.mark.parametrize("m", [2, 64, 224]) @@ -118,48 +167,21 @@ def test_cutlass_moe_no_graph( triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - if True: - cutlass_moe_fp8_fn = modular_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - ) - else: - def cutlass_moe_fp8_fn( - a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - a1_scale=a_scale1 - ): - return cutlass_moe_fp8( - a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale1 - ) - - cutlass_output = cutlass_moe_fp8_fn( - a, - w1_q, - w2_q, - w1_scale=w1_scale, - w2_scale=w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - a1_scale=a_scale1) + cutlass_moe_fp8_fn = get_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) + + cutlass_output = cutlass_moe_fp8_fn(a, + w1_q, + w2_q, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + a1_scale=a_scale1) #print(triton_output) #print(cutlass_output) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 4ebf48d026cc..68783501a72c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" -from typing import Optional +from typing import Optional, Tuple import torch -from vllm import _custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.utils import (_resize_cache, - _fp8_perm) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -185,39 +184,42 @@ def cutlass_moe_fp8( class CutlassDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, out_dtype: torch.dtype): super().__init__() self.out_dtype = out_dtype def dispatch( - self, - a: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: # why do we need to check a2_scale here? per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) + a1q, a1q_scale = ops.scaled_fp8_quant( + a1, a1_scale, use_per_token_if_dynamic=per_act_token) - return a_q, a1_scale, topk_ids + return a1q, a1_scale, topk_ids def combine( - self, - out: torch.Tensor, #TBD - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - ) -> torch.Tensor: + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: M, topk = topk_weights.shape - K = hidden_states.shape[1] - hidden_states = (hidden_states.view(-1, topk, K) * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) - # use moe_sum? to write into out? - return hidden_states + K = fused_expert_output.shape[1] + fused_expert_output = fused_expert_output.view( + -1, topk, K) * topk_weights.view( + M, -1, 1) #.to(self.out_dtype)).sum(dim=1) + assert output.dtype == self.out_dtype + ops.moe_sum(fused_expert_output, output) ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, @@ -236,106 +238,85 @@ def combine( class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( - self, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, + self, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype: torch.dtype, ): super().__init__() self.ab_strides1 = ab_strides1 self.c_strides1 = c_strides1 self.ab_strides2 = ab_strides2 self.c_strides2 = c_strides2 + self.out_dtype = out_dtype def workspace_shapes( self, M: int, - K: int, + K: int, # Note that K, N are transposed N: int, topk: int, - num_experts: int - ) -> Tuple[int, int]: + num_experts: int) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N - # return tuples???? - return (workspace1, workspace2) + return (workspace1, workspace2, self.out_dtype) def apply( - self, - out: torch.Tensor, # TBD - q_hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool, - activation: str, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - ) -> torch.Tensor: # or None? assume inplace? - # chunking in here or in ModularFusedMoEKernel? ignore for now - M = q_hidden_states.shape[0] - E, N, _ = w2.shape # because w1 + w2 are transposed - K = w1.shape[1] #? + self, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + # TODO: chunking in here or in FusedMoEModularKernel? ignore for now + M = a1q.shape[0] + E, N, K = w2.shape # because w1 + w2 are transposed topk = topk_ids.shape[1] - assert K == w2.shape[-1] - assert E == w1.shape[0] - device = q_hidden_states.device + device = a1q.device - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + assert w1.shape[1] == K + assert w1.shape[0] == E - expert_offsets = torch.empty((E + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((E, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((E, 3), - dtype=torch.int32, - device=device) + per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + expert_offsets = torch.empty((E + 1), dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device=device) - #print(f"prob {k}, {n}") + a_map = torch.empty((topk_ids.numel()), + dtype=torch.int32, + device=device) + c_map = torch.empty((topk_ids.numel()), + dtype=torch.int32, + device=device) - ops.get_cutlass_moe_mm_data(topk_ids, - expert_offsets, - problem_sizes1, - problem_sizes2, - a_map, - c_map, - E, - N, - K) + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, E, N, K) - q_hidden_states = _fp8_perm(q_hidden_states, a_map) - a1_scale = a1_scale[a_map] if per_act_token else a1_scale + a1q = _fp8_perm(a1q, a_map) + a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale # fix names c1 = _resize_cache(workspace13, (M * topk, N * 2)) c2 = _resize_cache(workspace2, (M * topk, N)) c3 = _resize_cache(workspace13, (M * topk, K)) - ops.cutlass_moe_mm( - c1, - q_hidden_states, - w1, - a1_scale, - w1_scale, - expert_offsets[:-1], - problem_sizes1, - self.ab_strides1, - self.ab_strides1, - self.c_strides1 - ) + ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, + expert_offsets[:-1], problem_sizes1, + self.ab_strides1, self.ab_strides1, self.c_strides1) if activation == "silu": torch.ops._C.silu_and_mul(c2, c1) @@ -344,21 +325,12 @@ def apply( else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - intemediate_q, a2_scale = ops.scaled_fp8_quant( + a2q, a2q_scale = ops.scaled_fp8_quant( c2, a2_scale, use_per_token_if_dynamic=per_act_token) - ops.cutlass_moe_mm( - c3, - intemediate_q, - w2, - a2_scale, - w2_scale, - expert_offsets[:-1], - problem_sizes2, - self.ab_strides2, - self.ab_strides2, - self.c_strides2 - ) + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, + self.ab_strides2, self.ab_strides2, self.c_strides2) c3 = c3[c_map, ...] @@ -366,11 +338,11 @@ def apply( def modular_cutlass_moe_fp8( - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - out_dtype: torch.dtype = torch.half, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( CutlassDispatchCombine(out_dtype), @@ -379,5 +351,6 @@ def modular_cutlass_moe_fp8( c_strides1, ab_strides2, c_strides2, + out_dtype, ), ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 3b39e45c7ea4..f170e4a02d58 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger -import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, @@ -304,123 +304,109 @@ def deep_gemm_moe_fp8( class DeepGemmDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() def dispatch( - self, - a: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - q_hidden_states, q_hidden_states_scale = _fp8_quantize( - a, + a1q, a1q_scale = _fp8_quantize( + a1, a1_scale, self.block_shape, ) - return q_hidden_states, q_hidden_states_scale, topk_ids + return a1q, a1q_scale, topk_ids def combine( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - ) -> torch.Tensor: - _moe_unpermute_and_reduce( - out, - hidden_states, - None, - topk_weights - ) - return out + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + _moe_unpermute_and_reduce(output, fused_expert_output, None, + topk_weights) class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() + self.out_dtype = torch.bfloat16 - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int]: + def workspace_shapes(self, M: int, N: int, K: int, topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N - # return tuples???? - return (workspace1, workspace2) # TODO add type + return (workspace1, workspace2, self.out_dtype) def apply( - self, - out: torch.Tensor, #unused tbd - q_hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool, - activation: str, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - ) -> torch.Tensor: # or None? assume inplace? + self, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: import deep_gemm as dg - # chunking in here or in ModularFusedMoEKernel? ignore for now - E, N, _ = w1.shape - _, K, _ = w2.shape + # TODO: chunking in here or in FusedMoEModularKernel? ignore for now + #E, N, _ = w1.shape + #_, K, _ = w2.shape + E, N, K = w1.shape - #print(f"M_sum = {M_sum}") + assert w2.shape[1] == K + assert w2.shape[0] == E - q_hidden_states, a1_scale, _, expert_ids, inv_perm = _moe_permute( - q_hidden_states, - a1_scale, + a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( + a1q, + a1q_scale, topk_ids, E, expert_map, self.block_shape[0], ) - M_sum = q_hidden_states.shape[0] + # Note: M_sum is different than the pre-permuted shape of a1q. + M_sum = a1q.shape[0] workspace1 = _resize_cache(workspace13, (M_sum, N)) workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) workspace3 = _resize_cache(workspace13, (M_sum, K)) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (q_hidden_states, a1_scale), (w1, w1_scale), - workspace1, - expert_ids) + (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, - workspace1.view(-1, N)) + torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, - workspace1.view(-1, N)) + torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") a2q_scale: Optional[torch.Tensor] = None - qworkspace2, a2q_scale = _fp8_quantize( - workspace2, a2_scale, self.block_shape) + a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, self.block_shape) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qworkspace2, a2q_scale), (w2, w2_scale), - workspace3, expert_ids) + (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) workspace3 = workspace3[inv_perm, ...] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e00878276942..a209715ede77 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1420,10 +1420,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - if True: - intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) - intermediate_cache3.mul_(curr_topk_weights.view(tokens_in_chunk, -1, 1)) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c780a494f4e3..ce08d984c3aa 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,160 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Optional, Tuple +from typing import Optional, Tuple + import torch +# TODO: add comments + class FusedMoEQuantizeDispatchCombine(ABC): + def __init__(self): pass @abstractmethod def dispatch( - self, - a: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: # TODO: figure this out - # returns (quantized+dispatched hidden_states, quantized+dispatched scales, dispatched topk_ids) + # returns (quantized+dispatched a, quantized+dispatched a1_scales, dispatched topk_ids) raise NotImplementedError @abstractmethod def combine( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - ) -> torch.Tensor: + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, # not reduced or weighted + topk_weights: torch.Tensor, + ) -> None: raise NotImplementedError # store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): + def __init__(self): pass @abstractmethod - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int]: + def workspace_shapes(self, M: int, N: int, K: int, topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: raise NotImplementedError @abstractmethod def apply( - self, - out: torch.Tensor, - q_hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool, - activation: str, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - ) -> torch.Tensor: # or None? assume inplace? + self, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: raise NotImplementedError # Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) # TODO: permute/unpermute must be paired -class FusedMoEModularKernel(torch.nn.Module): # should this be a module? +class FusedMoEModularKernel(torch.nn.Module): # should this be a module? + def __init__( - self, - dispatch_combine: FusedMoEQuantizeDispatchCombine, - fused_experts: FusedMoEPermuteExpertsUnpermute, + self, + dispatch_combine: FusedMoEQuantizeDispatchCombine, + fused_experts: FusedMoEPermuteExpertsUnpermute, ): super().__init__() self.dispatch_combine = dispatch_combine self.fused_experts = fused_experts def forward( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, + self, + a1: torch.Tensor, # aka hidden states + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - M, _ = hidden_states.shape + M, _ = a1.shape E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] - assert not inplace, "NYI" - if inplace: - out_hidden_states = hidden_states + output = a1 else: - out_hidden_states = torch.empty_like(hidden_states) - - #print(f"TKN = {topk_ids.numel()} {M*top_k}") + output = torch.empty_like(a1) - workspace13_shape, workspace2_shape = ( - self.fused_experts.workspace_shapes( - M, N, K, top_k, global_num_experts) - ) + workspace13_shape, workspace2_shape, workspace_dtype = ( + self.fused_experts.workspace_shapes(M, N, K, top_k, + global_num_experts)) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 workspace13 = torch.empty(workspace13_shape, - device=hidden_states.device, - dtype=hidden_states.dtype) + device=a1.device, + dtype=workspace_dtype) workspace2 = torch.empty(workspace2_shape, - device=hidden_states.device, - dtype=hidden_states.dtype) - - #print(f"\nbefore M = {hidden_states.shape[0]}") - - hidden_states, a1_scale, new_topk_ids = self.dispatch_combine.dispatch( - a=hidden_states, - a1_scale=a1_scale, - a2_scale=a2_scale, - topk_ids=topk_ids, - num_experts=global_num_experts, - expert_map=expert_map, + device=a1.device, + dtype=workspace_dtype) + + a1q, a1q_scale, dispatched_topk_ids = self.dispatch_combine.dispatch( + a1, + a1_scale, + a2_scale, + topk_ids, + global_num_experts, + expert_map, ) - #print(f"after M = {hidden_states.shape[0]}") - fused_out = self.fused_experts.apply( - out=hidden_states, - q_hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_ids=new_topk_ids, - inplace=inplace, - activation=activation, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + a1q, + w1, + w2, + dispatched_topk_ids, + activation, + expert_map, + w1_scale, + w2_scale, + a1q_scale, + a2_scale, workspace13=workspace13, workspace2=workspace2, ) - return self.dispatch_combine.combine(out_hidden_states, fused_out, topk_weights) + self.dispatch_combine.combine(output, fused_out, topk_weights) + + return output From c2ce01a2258416f5d93f0575cac73c5f3c007f27 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 22:20:58 +0000 Subject: [PATCH 009/171] fix inplace, format + name cleanups Signed-off-by: Bill Nell --- .../layers/fused_moe/modular_kernel.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ce08d984c3aa..3bef7ee30d16 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -9,8 +9,8 @@ class FusedMoEQuantizeDispatchCombine(ABC): - def __init__(self): - pass + # def __init__(self): + # pass @abstractmethod def dispatch( @@ -23,7 +23,9 @@ def dispatch( expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: # TODO: figure this out - # returns (quantized+dispatched a, quantized+dispatched a1_scales, dispatched topk_ids) + # returns (quantized+dispatched a, + # quantized+dispatched a1_scales, + # dispatched topk_ids) raise NotImplementedError @abstractmethod @@ -39,8 +41,8 @@ def combine( # store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): - def __init__(self): - pass + # def __init__(self): + # pass @abstractmethod def workspace_shapes(self, M: int, N: int, K: int, topk: int, @@ -66,8 +68,8 @@ def apply( raise NotImplementedError -# Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) -# TODO: permute/unpermute must be paired +# Note: only intended for use with a single model layer (due to temp buffers, +# constants, etc.) class FusedMoEModularKernel(torch.nn.Module): # should this be a module? def __init__( @@ -103,10 +105,7 @@ def forward( global_num_experts = E top_k = topk_ids.shape[1] - if inplace: - output = a1 - else: - output = torch.empty_like(a1) + output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( self.fused_experts.workspace_shapes(M, N, K, top_k, From b52b50d8d9bdc3a3b8ac61cf50d160ce549568e2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 01:18:53 +0000 Subject: [PATCH 010/171] test improvements Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 18 +++++---------- .../layers/fused_moe/deep_gemm_moe.py | 22 ++++++++----------- .../layers/fused_moe/fused_moe.py | 5 ++++- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 3e372b007aac..70404dbe49a1 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,9 +11,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8) + modular_deep_gemm_fused_moe_fp8, + _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -382,13 +382,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes TODO: use _valid_deep_gemm here instead? - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - if False and N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") vllm_config = VllmConfig() @@ -404,10 +402,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): - # pytest.skip( - # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index f170e4a02d58..943e383f3bce 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch @@ -20,12 +20,18 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None -def deep_gemm_block_shape() -> List[int]: +def deep_gemm_block_shape() -> list[int]: + # Lazy import to avoid CUDA initialization problems. import deep_gemm as dg block = dg.get_m_alignment_for_contiguous_layout() return [block, block] +def _valid_deep_gemm_shape(M: int, N: int, K: int): + align = deep_gemm_block_shape()[0] + return M >= align and N % align == 0 and K % align == 0 + + # TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, @@ -39,23 +45,13 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, if not has_deep_gemm: return False - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - # Expert maps not supported yet. if expert_map is not None: return False - align = dg.get_m_alignment_for_contiguous_layout() M = hidden_states.shape[0] _, K, N = w2.shape - - # For now, disable DeepGemm for small N until better permute/unpermute - # ops are available. - if N <= 512: - return False - - if align > M or N % align != 0 or K % align != 0: + if not _valid_deep_gemm_shape(M, N, K): return False return (hidden_states.is_contiguous() and w1.is_contiguous() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a209715ede77..d6fbfdc61fc5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1131,7 +1131,10 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, allow_deep_gemm: bool = False) -> torch.Tensor: - if (allow_deep_gemm and use_fp8_w8a8 + # For now, disable DeepGemm for small N (<= 512) until better + # permute/unpermute ops are available. + N = w1.shape[1] + if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( From 49a9d11b9f0c83c646a75201375ed439052696ed Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 04:41:01 +0000 Subject: [PATCH 011/171] make modular triton classes, fix edge cases Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 59 ++-- .../layers/fused_moe/cutlass_moe.py | 37 +-- .../layers/fused_moe/deep_gemm_moe.py | 27 +- .../layers/fused_moe/fused_moe.py | 258 ++++++++++++++++++ .../layers/fused_moe/modular_kernel.py | 32 ++- 5 files changed, 355 insertions(+), 58 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 425f36984a33..cbc20a57cf19 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,8 +11,10 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, - torch_moe_single) +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( @@ -67,31 +69,34 @@ def test_fused_moe( else: e_map = None - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w1, w2, score, topk, e_map) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, @@ -112,7 +117,7 @@ def test_fused_moe( def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, has_zp: bool, weight_bits: int): - print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) + #print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 68783501a72c..75d25f418d1e 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -214,10 +214,9 @@ def combine( topk_weights: torch.Tensor, ) -> None: M, topk = topk_weights.shape - K = fused_expert_output.shape[1] - fused_expert_output = fused_expert_output.view( - -1, topk, K) * topk_weights.view( - M, -1, 1) #.to(self.out_dtype)).sum(dim=1) + K = fused_expert_output.shape[-1] + fused_expert_output = (fused_expert_output.view(-1, topk, K) * + topk_weights.view(M, -1, 1)) assert output.dtype == self.out_dtype ops.moe_sum(fused_expert_output, output) @@ -255,12 +254,14 @@ def __init__( self.out_dtype = out_dtype def workspace_shapes( - self, - M: int, - K: int, # Note that K, N are transposed - N: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + self, + a_dtype: torch.dtype, + M: int, + K: int, # Note that K, N are transposed + N: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N return (workspace1, workspace2, self.out_dtype) @@ -272,9 +273,12 @@ def apply( w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, + global_num_experts: int, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -282,19 +286,19 @@ def apply( ) -> torch.Tensor: # TODO: chunking in here or in FusedMoEModularKernel? ignore for now M = a1q.shape[0] - E, N, K = w2.shape # because w1 + w2 are transposed + _, N, K = w2.shape # because w1 + w2 are transposed topk = topk_ids.shape[1] device = a1q.device assert w1.shape[1] == K - assert w1.shape[0] == E + assert global_num_experts != -1 per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - expert_offsets = torch.empty((E + 1), dtype=torch.int32, device=device) - problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device=device) - problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device=device) + expert_offsets = torch.empty((global_num_experts + 1), dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, @@ -304,7 +308,8 @@ def apply( device=device) ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, E, N, K) + problem_sizes2, a_map, c_map, global_num_experts, + N, K) a1q = _fp8_perm(a1q, a_map) a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 943e383f3bce..2ecba0be45c5 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -111,7 +111,7 @@ def _moe_unpermute_and_reduce( reduction on the hidden states. """ M, topk = topk_weight.shape - K = curr_hidden.shape[1] + K = curr_hidden.shape[-1] if inv_perm is not None: curr_hidden = curr_hidden[inv_perm, ...] curr_hidden = curr_hidden.view(-1, topk, K) @@ -336,16 +336,22 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() - self.out_dtype = torch.bfloat16 - def workspace_shapes(self, M: int, N: int, K: int, topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N - return (workspace1, workspace2, self.out_dtype) + return (workspace1, workspace2, a_dtype) def apply( self, @@ -354,9 +360,12 @@ def apply( w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, + global_num_experts: int, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -365,18 +374,16 @@ def apply( import deep_gemm as dg # TODO: chunking in here or in FusedMoEModularKernel? ignore for now - #E, N, _ = w1.shape - #_, K, _ = w2.shape - E, N, K = w1.shape + _, N, K = w1.shape + assert global_num_experts != -1 assert w2.shape[1] == K - assert w2.shape[0] == E a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( a1q, a1q_scale, topk_ids, - E, + global_num_experts, expert_map, self.block_shape[0], ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d6fbfdc61fc5..d14ab5e1283c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -10,6 +10,7 @@ import triton.language as tl import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( @@ -20,6 +21,8 @@ per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, + _resize_cache) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -1152,6 +1155,30 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale=a1_scale, a2_scale=a2_scale, ) + elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: + fe = modular_triton_fused_moe( + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_shape, + ) + return fe( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1159,6 +1186,7 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, @@ -1540,3 +1568,233 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) + + +class TritonDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + + def __init__(self, use_fp8_w8a8: bool, block_shape: Optional[List[int]]): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.block_shape = block_shape + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + if self.use_fp8_w8a8: + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + ) + else: + a1q = a1 + a1q_scale = a1_scale + + return a1q, a1q_scale, topk_ids + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + M, topk = topk_weights.shape + K = fused_expert_output.shape[-1] + fused_expert_output = fused_expert_output.view(-1, topk, K) + fused_expert_output.mul_(topk_weights.view(M, -1, 1)) + ops.moe_sum(fused_expert_output, output) + + +class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]], + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: + workspace1 = M * topk * max(N * 2, K) + workspace2 = M * topk * N + return (workspace1, workspace2, a_dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + # Check constraints. + if self.use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.shape[1] + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + M = num_tokens + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + top_k_num, + config_dtype, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + + curr_hidden_states = hidden_states + tokens_in_chunk, _ = curr_hidden_states.shape + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, N)) + intermediate_cache2 = _resize_cache(workspace2, (tokens_in_chunk * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, K)) + + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids + + qcurr_hidden_states, a1q_scale = hidden_states, a1q_scale + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + block_shape=self.block_shape) + + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + if self.use_fp8_w8a8: + qintermediate_cache2, a2q_scale = _fp8_quantize( + intermediate_cache2, a2_scale, self.block_shape) + else: + qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + block_shape=self.block_shape) + + return intermediate_cache3 + + +def modular_triton_fused_moe( + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + TritonDispatchCombine(use_fp8_w8a8, block_shape), + TritonExperts( + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_shape, + ), + ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3bef7ee30d16..08a004f75656 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -45,8 +45,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): # pass @abstractmethod - def workspace_shapes(self, M: int, N: int, K: int, topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: raise NotImplementedError @abstractmethod @@ -57,9 +64,12 @@ def apply( w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, + global_num_experts: int, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -100,7 +110,9 @@ def forward( a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: M, _ = a1.shape - E, K, N = w2.shape + E, N, _ = w1.shape + K = w2.shape[1] + #E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] @@ -108,8 +120,15 @@ def forward( output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(M, N, K, top_k, - global_num_experts)) + self.fused_experts.workspace_shapes( + a1.dtype, + M, + N, + K, + top_k, + global_num_experts + ) + ) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 @@ -135,9 +154,12 @@ def forward( w2, dispatched_topk_ids, activation, + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, + w2_zp, a1q_scale, a2_scale, workspace13=workspace13, From bcac19a7209227875f14a887e882b9044df5d4ec Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 05:33:17 +0000 Subject: [PATCH 012/171] fix outplace bug Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d14ab5e1283c..f86a22e5766c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1186,7 +1186,6 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, From e2ab4f581d8f27adfe3f804c491bdb3219c94d82 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 17:17:16 +0000 Subject: [PATCH 013/171] refactor dispatch/combine stuff Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 58 +--------- .../layers/fused_moe/deep_gemm_moe.py | 108 ++---------------- .../layers/fused_moe/dispatch_combine.py | 44 +++++++ .../layers/fused_moe/modular_kernel.py | 7 +- .../layers/fused_moe/moe_permute_unpermute.py | 68 +++++++++++ .../layers/fused_moe/pplx_dispatch_combine.py | 64 +++++++++++ vllm/model_executor/layers/fused_moe/utils.py | 9 +- 7 files changed, 202 insertions(+), 156 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/dispatch_combine.py create mode 100644 vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py create mode 100644 vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 75d25f418d1e..77fa6daec95c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,6 +7,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -183,59 +186,6 @@ def cutlass_moe_fp8( return c2.sum(dim=1) -class CutlassDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - - def __init__(self, out_dtype: torch.dtype): - super().__init__() - self.out_dtype = out_dtype - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - # why do we need to check a2_scale here? - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - - a1q, a1q_scale = ops.scaled_fp8_quant( - a1, a1_scale, use_per_token_if_dynamic=per_act_token) - - return a1q, a1_scale, topk_ids - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - ) -> None: - M, topk = topk_weights.shape - K = fused_expert_output.shape[-1] - fused_expert_output = (fused_expert_output.view(-1, topk, K) * - topk_weights.view(M, -1, 1)) - assert output.dtype == self.out_dtype - ops.moe_sum(fused_expert_output, output) - - ops.get_cutlass_moe_mm_data(topk_ids, - expert_offsets, - problem_sizes1, - problem_sizes2, - a_map, - c_map, - num_experts, - k, - n) - - rep_a_q = _fp8_perm(a_q, a_map) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - - return rep_a_q, rep_a1_scales, expert_offsets, c_map, (problem_sizes1, problem_sizes2) - - class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -350,7 +300,7 @@ def modular_cutlass_moe_fp8( out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - CutlassDispatchCombine(out_dtype), + StandardDispatchCombine(), CutlassExperts( ab_strides1, c_strides1, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 2ecba0be45c5..550a81536930 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -6,13 +6,16 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute, + _moe_unpermute_and_reduce +) +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) from vllm.utils import round_up logger = init_logger(__name__) @@ -58,67 +61,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, and w2.is_contiguous()) -def _moe_permute( - curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - curr_topk_ids: torch.Tensor, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - block_m: int, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - Optional[torch.Tensor]]: - """ - Determine the sorted_token_ids, expert_ids for the given problem size. - Permute the hidden states and scales according to `sorted_token_ids`. - """ - top_k_num = curr_topk_ids.shape[1] - - tokens_in_chunk, _ = curr_hidden_states.shape - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) - - inv_perm: Optional[torch.Tensor] = None - - num_tokens = top_k_num * tokens_in_chunk - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] - - # Permute according to sorted token ids. - curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) - - if a1q_scale is not None: - a1q_scale = a1q_scale[sorted_token_ids // top_k_num] - - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) - - -def _moe_unpermute_and_reduce( - out: torch.Tensor, - curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], - topk_weight: torch.Tensor, -) -> None: - """ - Unpermute the final result and apply topk_weights, then perform the final - reduction on the hidden states. - """ - M, topk = topk_weight.shape - K = curr_hidden.shape[-1] - if inv_perm is not None: - curr_hidden = curr_hidden[inv_perm, ...] - curr_hidden = curr_hidden.view(-1, topk, K) - curr_hidden.mul_(topk_weight.view(M, -1, 1)) - ops.moe_sum(curr_hidden, out) - - def deep_gemm_moe_fp8( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -299,38 +241,6 @@ def deep_gemm_moe_fp8( return out_hidden_states -class DeepGemmDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - - def __init__(self): - super().__init__() - self.block_shape = deep_gemm_block_shape() - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - ) - return a1q, a1q_scale, topk_ids - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - ) -> None: - _moe_unpermute_and_reduce(output, fused_expert_output, None, - topk_weights) - - class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): @@ -418,6 +328,6 @@ def apply( def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - DeepGemmDispatchCombine(), + StandardDispatchCombine(deep_gemm_block_shape()), DeepGemmExperts(), ) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py new file mode 100644 index 000000000000..589955fb65d1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -0,0 +1,44 @@ +import torch +from typing import Optional, Tuple + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_unpermute_and_reduce +) + +class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + + def __init__(self, block_shape: Optional[list[int]] = None): + super().__init__() + self.block_shape = block_shape + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + return a1q, a1q_scale, topk_ids + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + _moe_unpermute_and_reduce(output, fused_expert_output, None, + topk_weights) + diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 08a004f75656..b7582bcb4fe2 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -109,10 +109,15 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # Note: extracting the problem shape from the weight and activation tensors is + # tricky. It needs to be done this way specifically due to subtle issues with + # particular kernels, e.g. the int4 kernels divide the trailing dimension by + # two, so it's not "correct" to extract N or K from the trailing dimension of + # w1 or w2. Similarly, some kernels transpose the weights, so this needs to + # be kept in mind. M, _ = a1.shape E, N, _ = w1.shape K = w2.shape[1] - #E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py new file mode 100644 index 000000000000..60e1877ad865 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -0,0 +1,68 @@ +import torch +from typing import Optional, Tuple + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. + """ + top_k_num = curr_topk_ids.shape[1] + + tokens_in_chunk, _ = curr_hidden_states.shape + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. + curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M, topk = topk_weight.shape + K = curr_hidden.shape[-1] + if inv_perm is not None: + curr_hidden = curr_hidden[inv_perm, ...] + curr_hidden = curr_hidden.view(-1, topk, K) + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py new file mode 100644 index 000000000000..1eb500d932a1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -0,0 +1,64 @@ +import torch +from typing import Optional, Tuple + +import pplx_kernels as pplx +import vllm.model_executor.layers.fused_moe.modular_kernel as mk + + +class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, a2a: pplx.AllToAll): + super().__init__() + self.a2a = a2a + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + self.a2a.dispatch( + out_expert_num_tokens, # torch.Tensor, + out_expert_x, # torch.Tensor, + out_expert_x_scale, # torch.Tensor | None, + dp_x, # torch.Tensor, + dp_x_scale, # torch.Tensor | None, + indices, # torch.Tensor, + bound_m, # torch.Tensor | None, + do_send, # bool = True, + do_recv, # bool = True, + ) + return 1q, a1q_scale, topk_ids + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + self.a2a.combine( + out_tokens, #: torch.Tensor, + indices, #: torch.Tensor, + weights, #: torch.Tensor, + expert_y, #: torch.Tensor, + bound_m, #: torch.Tensor | None, + do_send, #: bool = True, + do_recv, #: bool = True, + ) + + +# singleton-ish +def get_a2a( + max_num_tokens: int, + num_experts: int, + experts_per_token: int, + rank: int, + world_size: int, + dp_size: int, + hidden_dim: int, + hidden_dim_bytes: int, + hidden_dim_scale_bytes: int, +) -> pplx.AllToAll: + pass diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index db31422f7275..ee8e8857fabd 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -22,14 +22,19 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], - block_shape: Optional[List[int]], + block_shape: Optional[List[int]] = None, + per_act_token: bool = False, # make sure this is the same default as op ) -> Tuple[torch.Tensor, torch.Tensor]: """ Perform fp8 quantization on the inputs. If a block_shape is provided, the output will be blocked. """ if block_shape is None: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) + A, A_scale = ops.scaled_fp8_quant( + A, + A_scale, + use_per_token_if_dynamic=per_act_token + ) else: assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] From ecaca4e3dd8b6ed0920ce2ae01a2fabffcd732ad Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 19:39:35 +0000 Subject: [PATCH 014/171] initial pplx dispatch/combine class Signed-off-by: Bill Nell --- .../layers/fused_moe/dispatch_combine.py | 6 +- .../layers/fused_moe/fused_moe.py | 6 +- .../layers/fused_moe/modular_kernel.py | 20 ++- .../layers/fused_moe/pplx_dispatch_combine.py | 114 ++++++++++++------ 4 files changed, 92 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 589955fb65d1..cd981cfb6961 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -21,7 +21,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) @@ -31,14 +31,14 @@ def dispatch( self.block_shape, per_act_token, ) - return a1q, a1q_scale, topk_ids + return a1q, a1q_scale def combine( self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: _moe_unpermute_and_reduce(output, fused_expert_output, None, topk_weights) - diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f86a22e5766c..4e84436bbae4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1569,6 +1569,7 @@ def fused_moe( block_shape=block_shape) +# TODO: merge with StandardDispatchCombine class TritonDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): def __init__(self, use_fp8_w8a8: bool, block_shape: Optional[List[int]]): @@ -1584,7 +1585,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.use_fp8_w8a8: a1q, a1q_scale = _fp8_quantize( a1, @@ -1595,13 +1596,14 @@ def dispatch( a1q = a1 a1q_scale = a1_scale - return a1q, a1q_scale, topk_ids + return a1q, a1q_scale def combine( self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: M, topk = topk_weights.shape K = fused_expert_output.shape[-1] diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b7582bcb4fe2..6ff85c21ceec 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -9,9 +9,6 @@ class FusedMoEQuantizeDispatchCombine(ABC): - # def __init__(self): - # pass - @abstractmethod def dispatch( self, @@ -21,11 +18,9 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - # TODO: figure this out + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # returns (quantized+dispatched a, - # quantized+dispatched a1_scales, - # dispatched topk_ids) + # quantized+dispatched a1_scales) raise NotImplementedError @abstractmethod @@ -34,6 +29,7 @@ def combine( output: torch.Tensor, fused_expert_output: torch.Tensor, # not reduced or weighted topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: raise NotImplementedError @@ -41,9 +37,6 @@ def combine( # store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): - # def __init__(self): - # pass - @abstractmethod def workspace_shapes( self, @@ -115,6 +108,7 @@ def forward( # two, so it's not "correct" to extract N or K from the trailing dimension of # w1 or w2. Similarly, some kernels transpose the weights, so this needs to # be kept in mind. + # TODO: make this a method/utility function, e.g. problem_size(a, w1, w2, topk_ids, ...) M, _ = a1.shape E, N, _ = w1.shape K = w2.shape[1] @@ -144,7 +138,7 @@ def forward( device=a1.device, dtype=workspace_dtype) - a1q, a1q_scale, dispatched_topk_ids = self.dispatch_combine.dispatch( + a1q, a1q_scale = self.dispatch_combine.dispatch( a1, a1_scale, a2_scale, @@ -157,7 +151,7 @@ def forward( a1q, w1, w2, - dispatched_topk_ids, + topk_ids, activation, global_num_experts, expert_map, @@ -171,6 +165,6 @@ def forward( workspace2=workspace2, ) - self.dispatch_combine.combine(output, fused_out, topk_weights) + self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids) return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 1eb500d932a1..fea0c5c1f16c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -1,64 +1,106 @@ import torch -from typing import Optional, Tuple +from typing import List, Optional, Tuple import pplx_kernels as pplx import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +# Note use: layer.get_all_to_all() to get an AllToAll instance +# The max_num_tokens, world_size and dp_size must be the same +# as the ones used to create the AllToAll. Unfortunately, there's +# no way(?) to extract this info from AllToAll class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, a2a: pplx.AllToAll): + def __init__( + self, + a2a: pplx.AllToAll, + max_num_tokens: int, + world_size: int, + dp_size: int, + block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a + self.block_shape = block_shape + self.dp_num_tokens = max_num_tokens * (world_size // dp_size) def dispatch( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, + rank_topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Is this always going to be a1.device? + device = a1.device + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + + expert_num_tokens = torch.empty( + num_experts, + dtype=torch.int32, + device=device, + ) + + expert_x = torch.empty( + (num_experts, self.dp_num_tokens, a1q.shape[-1]), + dtype=a1q.dtype, + device=device, + ) + + expert_x_scale: torch.Tensor | None = None + if a1q.dtype.itemsize == 1: + float32_size = torch.float32.itemsize + block_size = (self.block_shape[0] if self.block_shape is not None else 1) * float32_size + expert_x_scale = torch.empty( + ( + num_experts, + expert_x.size(1), + (expert_x.size(2) + block_size - 1) // block_size, + ), + dtype=torch.float32, + device=device, + ) + + # This argument is optional + bound_m = torch.tensor([a1q.shape[0]], dtype=torch.uint32, device=device) + self.a2a.dispatch( - out_expert_num_tokens, # torch.Tensor, - out_expert_x, # torch.Tensor, - out_expert_x_scale, # torch.Tensor | None, - dp_x, # torch.Tensor, - dp_x_scale, # torch.Tensor | None, - indices, # torch.Tensor, - bound_m, # torch.Tensor | None, - do_send, # bool = True, - do_recv, # bool = True, + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=rank_topk_ids, + bound_m=bound_m, ) - return 1q, a1q_scale, topk_ids + return expert_x, expert_x_scale def combine( self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: - self.a2a.combine( - out_tokens, #: torch.Tensor, - indices, #: torch.Tensor, - weights, #: torch.Tensor, - expert_y, #: torch.Tensor, - bound_m, #: torch.Tensor | None, - do_send, #: bool = True, - do_recv, #: bool = True, - ) + # This argument is optional + bound_m = torch.tensor([output.shape[0]], dtype=torch.uint32, device=output.device) + # TODO assert output is the proper size -# singleton-ish -def get_a2a( - max_num_tokens: int, - num_experts: int, - experts_per_token: int, - rank: int, - world_size: int, - dp_size: int, - hidden_dim: int, - hidden_dim_bytes: int, - hidden_dim_scale_bytes: int, -) -> pplx.AllToAll: - pass + self.a2a.combine( + out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m + ) From 9bcbde0edf905318dbd821eb6ba449969d9a2f60 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:41:40 +0000 Subject: [PATCH 015/171] merge triton dispatch into standard, add some comments Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 2 +- .../layers/fused_moe/deep_gemm_moe.py | 3 +- .../layers/fused_moe/dispatch_combine.py | 28 ++- .../layers/fused_moe/fused_moe.py | 51 +----- .../layers/fused_moe/modular_kernel.py | 172 ++++++++++++++++-- .../layers/fused_moe/pplx_dispatch_combine.py | 21 ++- 6 files changed, 196 insertions(+), 81 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 77fa6daec95c..7ea999d5086d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -300,7 +300,7 @@ def modular_cutlass_moe_fp8( out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - StandardDispatchCombine(), + StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn), CutlassExperts( ab_strides1, c_strides1, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 550a81536930..19c54dd2c31e 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -328,6 +328,7 @@ def apply( def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - StandardDispatchCombine(deep_gemm_block_shape()), + StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn, + block_shape=deep_gemm_block_shape()), DeepGemmExperts(), ) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index cd981cfb6961..207a1c698603 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -9,9 +9,14 @@ class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, block_shape: Optional[list[int]] = None): + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None + ): super().__init__() self.block_shape = block_shape + self.quant_dtype = quant_dtype def dispatch( self, @@ -22,15 +27,20 @@ def dispatch( num_experts: int, expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + if self.quant_dtype == torch.float8_e4m3fn: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + else: + a1q = a1 + a1q_scale = a1_scale - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) return a1q, a1q_scale def combine( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4e84436bbae4..c0f4a39da322 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -15,6 +15,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -1569,49 +1572,6 @@ def fused_moe( block_shape=block_shape) -# TODO: merge with StandardDispatchCombine -class TritonDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - - def __init__(self, use_fp8_w8a8: bool, block_shape: Optional[List[int]]): - super().__init__() - self.use_fp8_w8a8 = use_fp8_w8a8 - self.block_shape = block_shape - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.use_fp8_w8a8: - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - ) - else: - a1q = a1 - a1q_scale = a1_scale - - return a1q, a1q_scale - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> None: - M, topk = topk_weights.shape - K = fused_expert_output.shape[-1] - fused_expert_output = fused_expert_output.view(-1, topk, K) - fused_expert_output.mul_(topk_weights.view(M, -1, 1)) - ops.moe_sum(fused_expert_output, output) - - class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, @@ -1791,7 +1751,10 @@ def modular_triton_fused_moe( block_shape: Optional[List[int]] = None, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - TritonDispatchCombine(use_fp8_w8a8, block_shape), + StandardDispatchCombine( + quant_dtype=torch.float8_e4m3fn if use_fp8_w8a8 else None, + block_shape=block_shape + ), TritonExperts( use_fp8_w8a8, use_int8_w8a16, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6ff85c21ceec..7f617a06e2d5 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,11 +4,51 @@ import torch -# TODO: add comments +def moe_problem_size( + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, +) -> Tuple[int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + Note: extracting the problem shape from the weight and activation tensors is + not obvious. It needs to be done this way specifically due to subtle issues + with particular kernels, e.g. the int4 kernels divide the trailing dimension + by two, so it's not "correct" to extract N or K from the trailing dimension + of w1 or w2. Similarly, some kernels transpose the weights, so this needs to + be kept in mind. + """ + # Make sure we are using the correct a1 (pre-permute) + assert topk_ids.shape[0] == a1.shape[0] + M, _ = a1.shape + E, N, _ = w1.shape + K = w2.shape[1] + topk = topk_ids.shape[1] + return E, M, N, K, topk -class FusedMoEQuantizeDispatchCombine(ABC): +# +# A set of base classes used to make MoE kernels more modular. +# +# Architecture: +# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] +# +# [Quantize-Dispatch] and [Combine] functionality are bundled into a single +# class `FusedMoEQuantizeDispatchCombine` since they could use collective +# communication mechanisms that need to be consistent. +# +# Ideal architecture: +# [Router] → [Quantize-Dispatch-Permute] → [Experts] → [Unpermute-Combine] +# +class FusedMoEQuantizeDispatchCombine(ABC): + """ + """ @abstractmethod def dispatch( self, @@ -19,22 +59,43 @@ def dispatch( num_experts: int, expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # returns (quantized+dispatched a, - # quantized+dispatched a1_scales) + """ + Perform any quantization (and/or) dispatching needed + for this kernel. + - a1: The (unquantized) input to the MoE layer. + - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make sure the quantization is consistent for both gemms. + - topk_ids: The topk_ids. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + + Returns a tuple of: + - quantized + dispatched a. + - quantized + dispatched a1_scales. + """ raise NotImplementedError @abstractmethod def combine( self, output: torch.Tensor, - fused_expert_output: torch.Tensor, # not reduced or weighted + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> None: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + """ raise NotImplementedError -# store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): @abstractmethod @@ -47,6 +108,19 @@ def workspace_shapes( topk: int, num_experts: int ) -> Tuple[int, int, torch.dtype]: + """ + Compute the number of elements for the temporary outputs of the two + gemms and activation in the fused expert function. Since the + gemms are independent, the workspace for the first gemm can be shared + with the workspace for the last gemm. + + Returns a tuple of: + - Number of workspace13 elements: must be large enough to hold the result + of either expert gemm. + - Number of workspace2 elements: must be large enough to hold the result + of the activation function. + - Workspace type: The dtype to use for the workspace tensors. + """ raise NotImplementedError @abstractmethod @@ -68,6 +142,42 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, ) -> torch.Tensor: + """ + This function computes the intermediate result of a Mixture of Experts (MoE) + layer using two sets of weights, w1 and w2. + + Parameters: + - a1q: (torch.Tensor): The (quantized) input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_ids (torch.Tensor): A map of row to expert id. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs + must be large enough to hold output of either MoE gemm. + - workspace2 (torch.Tensor): A scratch tensor used for the activation + function. + + Returns: + - torch.Tensor: The unweighted, unreduced output tensor + """ raise NotImplementedError @@ -86,7 +196,7 @@ def __init__( def forward( self, - a1: torch.Tensor, # aka hidden states + a1: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -102,19 +212,45 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # Note: extracting the problem shape from the weight and activation tensors is - # tricky. It needs to be done this way specifically due to subtle issues with - # particular kernels, e.g. the int4 kernels divide the trailing dimension by - # two, so it's not "correct" to extract N or K from the trailing dimension of - # w1 or w2. Similarly, some kernels transpose the weights, so this needs to - # be kept in mind. - # TODO: make this a method/utility function, e.g. problem_size(a, w1, w2, topk_ids, ...) - M, _ = a1.shape - E, N, _ = w1.shape - K = w2.shape[1] + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - a1: (torch.Tensor): The input tensor to the MoE layer (aka hidden_states). + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The topk weights applied at the end of the layer. + - topk_ids (torch.Tensor): A map of row to expert id. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + E, M, N, K, top_k = moe_problem_size(a1, w1, w2, topk_ids) + if global_num_experts == -1: global_num_experts = E - top_k = topk_ids.shape[1] output = a1 if inplace else torch.empty_like(a1) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fea0c5c1f16c..3bc6b50720cb 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -17,6 +17,7 @@ def __init__( max_num_tokens: int, world_size: int, dp_size: int, + quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a @@ -35,15 +36,19 @@ def dispatch( # Is this always going to be a1.device? device = a1.device - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + if self.quant_dtype == torch.float8_e4m3fn: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + else: + a1q = a1 + a1q_scale = a1_scale expert_num_tokens = torch.empty( num_experts, From 73847e011818fb5567d29a6e823bc91335dace2b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:47:41 +0000 Subject: [PATCH 016/171] format Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 4 +- .../layers/fused_moe/cutlass_moe.py | 36 +++--- .../layers/fused_moe/deep_gemm_moe.py | 25 ++-- .../layers/fused_moe/dispatch_combine.py | 21 ++-- .../layers/fused_moe/fused_moe.py | 84 +++++++------- .../layers/fused_moe/modular_kernel.py | 109 ++++++++---------- .../layers/fused_moe/moe_permute_unpermute.py | 4 +- .../layers/fused_moe/pplx_dispatch_combine.py | 46 ++++---- vllm/model_executor/layers/fused_moe/utils.py | 5 +- 9 files changed, 159 insertions(+), 175 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 70404dbe49a1..6691d28c83d8 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,9 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8, - _valid_deep_gemm_shape) + _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 7ea999d5086d..c6b50729b246 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -6,10 +6,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) + StandardDispatchCombine) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -204,14 +203,13 @@ def __init__( self.out_dtype = out_dtype def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - K: int, # Note that K, N are transposed - N: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + self, + a_dtype: torch.dtype, + M: int, + K: int, # Note that K, N are transposed + N: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N return (workspace1, workspace2, self.out_dtype) @@ -246,9 +244,15 @@ def apply( per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - expert_offsets = torch.empty((global_num_experts + 1), dtype=torch.int32, device=device) - problem_sizes1 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) - problem_sizes2 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) + expert_offsets = torch.empty((global_num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, @@ -258,8 +262,8 @@ def apply( device=device) ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, global_num_experts, - N, K) + problem_sizes2, a_map, c_map, + global_num_experts, N, K) a1q = _fp8_perm(a1q, a_map) a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 19c54dd2c31e..6ffb40cb52cb 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -7,15 +7,12 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine) +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute, _moe_unpermute_and_reduce) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_permute, - _moe_unpermute_and_reduce -) -from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) from vllm.utils import round_up logger = init_logger(__name__) @@ -32,7 +29,7 @@ def deep_gemm_block_shape() -> list[int]: def _valid_deep_gemm_shape(M: int, N: int, K: int): align = deep_gemm_block_shape()[0] - return M >= align and N % align == 0 and K % align == 0 + return align <= M and N % align == 0 and K % align == 0 # TODO: check types? @@ -247,15 +244,9 @@ def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 207a1c698603..06b90c350252 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -1,19 +1,19 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 from typing import Optional, Tuple +import torch + import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_unpermute_and_reduce -) + _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize + class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__( - self, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None - ): + def __init__(self, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): super().__init__() self.block_shape = block_shape self.quant_dtype = quant_dtype @@ -28,7 +28,8 @@ def dispatch( expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + per_act_token = a1_scale.numel( + ) != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) a1q, a1q_scale = _fp8_quantize( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c0f4a39da322..daae08f78b6e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -16,8 +16,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) + StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -1573,6 +1572,7 @@ def fused_moe( class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( self, use_fp8_w8a8: bool, @@ -1586,15 +1586,9 @@ def __init__( self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(N * 2, K) workspace2 = M * topk * N return (workspace1, workspace2, a_dtype) @@ -1622,9 +1616,11 @@ def apply( assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[ + 2], "Hidden size mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ @@ -1637,9 +1633,9 @@ def apply( if global_num_experts == -1: global_num_experts = E top_k_num = topk_ids.shape[1] + # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 - M = num_tokens config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, @@ -1663,16 +1659,20 @@ def apply( elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: - raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") curr_hidden_states = hidden_states tokens_in_chunk, _ = curr_hidden_states.shape # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, N)) - intermediate_cache2 = _resize_cache(workspace2, (tokens_in_chunk * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, K)) + intermediate_cache1 = _resize_cache(workspace13, + (tokens_in_chunk, top_k_num, N)) + intermediate_cache2 = _resize_cache( + workspace2, (tokens_in_chunk * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, + (tokens_in_chunk, top_k_num, K)) config = get_config_func(tokens_in_chunk) @@ -1721,40 +1721,38 @@ def apply( qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale - invoke_fused_moe_kernel( - qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - block_shape=self.block_shape) + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + block_shape=self.block_shape) return intermediate_cache3 def modular_triton_fused_moe( - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - block_shape: Optional[List[int]] = None, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( StandardDispatchCombine( quant_dtype=torch.float8_e4m3fn if use_fp8_w8a8 else None, - block_shape=block_shape - ), + block_shape=block_shape), TritonExperts( use_fp8_w8a8, use_int8_w8a16, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 7f617a06e2d5..196c29eca8a8 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,6 +4,7 @@ import torch + def moe_problem_size( a1: torch.Tensor, w1: torch.Tensor, @@ -21,8 +22,8 @@ def moe_problem_size( not obvious. It needs to be done this way specifically due to subtle issues with particular kernels, e.g. the int4 kernels divide the trailing dimension by two, so it's not "correct" to extract N or K from the trailing dimension - of w1 or w2. Similarly, some kernels transpose the weights, so this needs to - be kept in mind. + of w1 or w2. Similarly, some kernels transpose the weights, so this needs + to be kept in mind. """ # Make sure we are using the correct a1 (pre-permute) assert topk_ids.shape[0] == a1.shape[0] @@ -32,6 +33,7 @@ def moe_problem_size( topk = topk_ids.shape[1] return E, M, N, K, topk + # # A set of base classes used to make MoE kernels more modular. # @@ -46,9 +48,11 @@ def moe_problem_size( # [Router] → [Quantize-Dispatch-Permute] → [Experts] → [Unpermute-Combine] # + class FusedMoEQuantizeDispatchCombine(ABC): """ """ + @abstractmethod def dispatch( self, @@ -64,7 +68,8 @@ def dispatch( for this kernel. - a1: The (unquantized) input to the MoE layer. - a1_scale: Optional scales for a1 - - a2_scale: Optional scales for the second MoE gemm. Required to make sure the quantization is consistent for both gemms. + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. - topk_ids: The topk_ids. - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert @@ -99,15 +104,9 @@ def combine( class FusedMoEPermuteExpertsUnpermute(ABC): @abstractmethod - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: """ Compute the number of elements for the temporary outputs of the two gemms and activation in the fused expert function. Since the @@ -115,10 +114,10 @@ def workspace_shapes( with the workspace for the last gemm. Returns a tuple of: - - Number of workspace13 elements: must be large enough to hold the result - of either expert gemm. - - Number of workspace2 elements: must be large enough to hold the result - of the activation function. + - Number of workspace13 elements: must be large enough to hold the + result of either expert gemm. + - Number of workspace2 elements: must be large enough to hold the + result of the activation function. - Workspace type: The dtype to use for the workspace tensors. """ raise NotImplementedError @@ -143,8 +142,8 @@ def apply( workspace2: torch.Tensor, ) -> torch.Tensor: """ - This function computes the intermediate result of a Mixture of Experts (MoE) - layer using two sets of weights, w1 and w2. + This function computes the intermediate result of a Mixture of Experts + (MoE) layer using two sets of weights, w1 and w2. Parameters: - a1q: (torch.Tensor): The (quantized) input tensor to the MoE layer. @@ -152,24 +151,21 @@ def apply( - w2 (torch.Tensor): The second set of expert weights. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first - MoE layer. + MoE layer. - global_num_experts (int): The total number of experts in the global - expert space. + expert space. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. + w1. - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. + w2. + - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be + used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation @@ -213,36 +209,33 @@ def forward( a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. + This function computes a Mixture of Experts (MoE) layer using two sets + of weights, w1 and w2, and top-k gating mechanism. Parameters: - - a1: (torch.Tensor): The input tensor to the MoE layer (aka hidden_states). + - a1: (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - - topk_weights (torch.Tensor): The topk weights applied at the end of the layer. + - topk_weights (torch.Tensor): The topk weights applied at the end of + the layer. - topk_ids (torch.Tensor): A map of row to expert id. - inplace (bool): If True, perform the operation in-place. - Defaults to False. + Defaults to False. - activation (str): The activation function to apply after the first - MoE layer. + MoE layer. - global_num_experts (int): The total number of experts in the global - expert space. + expert space. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. + w1. - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -255,15 +248,8 @@ def forward( output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes( - a1.dtype, - M, - N, - K, - top_k, - global_num_experts - ) - ) + self.fused_experts.workspace_shapes(a1.dtype, M, N, K, top_k, + global_num_experts)) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 @@ -301,6 +287,7 @@ def forward( workspace2=workspace2, ) - self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids) + self.dispatch_combine.combine(output, fused_out, topk_weights, + topk_ids) return output diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 60e1877ad865..93a3d8ab9c18 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -1,6 +1,8 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 from typing import Optional, Tuple +import torch + from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 3bc6b50720cb..7219ea2c0a31 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -1,7 +1,9 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 from typing import List, Optional, Tuple import pplx_kernels as pplx +import torch + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize @@ -11,14 +13,14 @@ # as the ones used to create the AllToAll. Unfortunately, there's # no way(?) to extract this info from AllToAll class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__( - self, - a2a: pplx.AllToAll, - max_num_tokens: int, - world_size: int, - dp_size: int, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[List[int]] = None): + + def __init__(self, + a2a: pplx.AllToAll, + max_num_tokens: int, + world_size: int, + dp_size: int, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a self.block_shape = block_shape @@ -37,7 +39,8 @@ def dispatch( device = a1.device if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + per_act_token = a1_scale.numel( + ) != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) a1q, a1q_scale = _fp8_quantize( @@ -65,7 +68,8 @@ def dispatch( expert_x_scale: torch.Tensor | None = None if a1q.dtype.itemsize == 1: float32_size = torch.float32.itemsize - block_size = (self.block_shape[0] if self.block_shape is not None else 1) * float32_size + block_size = (self.block_shape[0] if self.block_shape is not None + else 1) * float32_size expert_x_scale = torch.empty( ( num_experts, @@ -77,7 +81,9 @@ def dispatch( ) # This argument is optional - bound_m = torch.tensor([a1q.shape[0]], dtype=torch.uint32, device=device) + bound_m = torch.tensor([a1q.shape[0]], + dtype=torch.uint32, + device=device) self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -98,14 +104,14 @@ def combine( topk_ids: torch.Tensor, ) -> None: # This argument is optional - bound_m = torch.tensor([output.shape[0]], dtype=torch.uint32, device=output.device) + bound_m = torch.tensor([output.shape[0]], + dtype=torch.uint32, + device=output.device) # TODO assert output is the proper size - self.a2a.combine( - out_tokens=output, - indices=topk_ids, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m - ) + self.a2a.combine(out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ee8e8857fabd..152007d42169 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -31,10 +31,7 @@ def _fp8_quantize( """ if block_shape is None: A, A_scale = ops.scaled_fp8_quant( - A, - A_scale, - use_per_token_if_dynamic=per_act_token - ) + A, A_scale, use_per_token_if_dynamic=per_act_token) else: assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] From b136032bc0c5d7c3a2bf8c1cce649e68c64c6ff8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 21:04:11 +0000 Subject: [PATCH 017/171] comments Signed-off-by: Bill Nell --- .../layers/fused_moe/modular_kernel.py | 53 ++++++++++++------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 196c29eca8a8..1b084b198f3c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,6 +4,24 @@ import torch +# +# This file defines a set of base classes used to make MoE kernels more modular. +# The goal is to be able to utilize different communication mechanisms with +# any fused MoE kernel without needing to have combinatoric implementations. +# +# Break the fused moe layer down into the following components. Each component +# will be independent of the others except for [Quantize-Dispatch] and +# [Combine]. The components can then be mixed and matched with different fused +# moe kernels so that DP+EP can be supported easily for multiple MoE +# implementations. +# +# Architecture: +# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] +# +# [Quantize-Dispatch] and [Combine] functionality are bundled into a single +# class `FusedMoEQuantizeDispatchCombine` since they could use collective +# communication mechanisms that need to be consistent. +# def moe_problem_size( a1: torch.Tensor, @@ -34,23 +52,10 @@ def moe_problem_size( return E, M, N, K, topk -# -# A set of base classes used to make MoE kernels more modular. -# -# Architecture: -# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] -# -# [Quantize-Dispatch] and [Combine] functionality are bundled into a single -# class `FusedMoEQuantizeDispatchCombine` since they could use collective -# communication mechanisms that need to be consistent. -# -# Ideal architecture: -# [Router] → [Quantize-Dispatch-Permute] → [Experts] → [Unpermute-Combine] -# - - class FusedMoEQuantizeDispatchCombine(ABC): """ + An abstract base class for the [Quantize-Dispatch] and [Combine] steps + described above. """ @abstractmethod @@ -102,6 +107,10 @@ def combine( class FusedMoEPermuteExpertsUnpermute(ABC): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above. + """ @abstractmethod def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, @@ -177,10 +186,18 @@ def apply( raise NotImplementedError -# Note: only intended for use with a single model layer (due to temp buffers, -# constants, etc.) -class FusedMoEModularKernel(torch.nn.Module): # should this be a module? +class FusedMoEModularKernel(torch.nn.Module): + """ + This class combines a FusedMoEQuantizeDispatchCombine instance and + a FusedMoEPermuteExpertsUnpermute to provide an interface that + is compatible with the `fused_experts` function in fused_moe.py. + + It takes care of managing any required scratch space. + Note: Instances of this class should only be used for a single model + layer due to any layer specific state that may be used by the component + objects. + """ def __init__( self, dispatch_combine: FusedMoEQuantizeDispatchCombine, From 62584bf1dcf64c4fbe9b589c2d111457310ce52e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 21:18:30 +0000 Subject: [PATCH 018/171] fix linter Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 5 ++++- vllm/model_executor/layers/fused_moe/modular_kernel.py | 2 +- .../model_executor/layers/fused_moe/pplx_dispatch_combine.py | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index c6b50729b246..ea903e2500a7 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -206,10 +206,12 @@ def workspace_shapes( self, a_dtype: torch.dtype, M: int, - K: int, # Note that K, N are transposed N: int, + K: int, topk: int, num_experts: int) -> Tuple[int, int, torch.dtype]: + # Note that K, N are transposed + N, K = K, N workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N return (workspace1, workspace2, self.out_dtype) @@ -240,6 +242,7 @@ def apply( assert w1.shape[1] == K assert global_num_experts != -1 + assert a1q_scale is not None per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 1b084b198f3c..f56790d4dcc3 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -28,7 +28,7 @@ def moe_problem_size( w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, -) -> Tuple[int, int, int, int]: +) -> Tuple[int, int, int, int, int]: """ Extract the MoE problem size from the given tensor arguments: - a: The hidden states, input to the MoE layer. diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 7219ea2c0a31..fc5ff1ae0209 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -25,6 +25,7 @@ def __init__(self, self.a2a = a2a self.block_shape = block_shape self.dp_num_tokens = max_num_tokens * (world_size // dp_size) + self.quant_dtype = quant_dtype def dispatch( self, From cbdc4710b3f1c098e696196eec3da0c916a6eafb Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 21:26:18 +0000 Subject: [PATCH 019/171] fix more linter stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 11 +++-------- .../model_executor/layers/fused_moe/modular_kernel.py | 5 ++++- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index ea903e2500a7..64e6d425bde2 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -202,14 +202,9 @@ def __init__( self.c_strides2 = c_strides2 self.out_dtype = out_dtype - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: # Note that K, N are transposed N, K = K, N workspace1 = M * topk * max(2 * N, K) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f56790d4dcc3..5db49a630a4a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -23,6 +23,7 @@ # communication mechanisms that need to be consistent. # + def moe_problem_size( a1: torch.Tensor, w1: torch.Tensor, @@ -43,7 +44,8 @@ def moe_problem_size( of w1 or w2. Similarly, some kernels transpose the weights, so this needs to be kept in mind. """ - # Make sure we are using the correct a1 (pre-permute) + + # Make sure we are using the correct a1 (pre-permute). assert topk_ids.shape[0] == a1.shape[0] M, _ = a1.shape E, N, _ = w1.shape @@ -198,6 +200,7 @@ class FusedMoEModularKernel(torch.nn.Module): layer due to any layer specific state that may be used by the component objects. """ + def __init__( self, dispatch_combine: FusedMoEQuantizeDispatchCombine, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fc5ff1ae0209..5c844ff57a76 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -66,7 +66,7 @@ def dispatch( device=device, ) - expert_x_scale: torch.Tensor | None = None + expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: float32_size = torch.float32.itemsize block_size = (self.block_shape[0] if self.block_shape is not None From 8e2c5b26558c8c3ebcaae66c822caf2cbaa44a48 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 23:20:14 +0000 Subject: [PATCH 020/171] cleanup for review Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 20 +- tests/kernels/test_cutlass_moe.py | 102 ++----- .../layers/fused_moe/cutlass_moe.py | 74 ++++- .../layers/fused_moe/deep_gemm_moe.py | 254 +++++------------- .../layers/fused_moe/fused_moe.py | 51 +--- .../layers/fused_moe/modular_kernel.py | 28 +- 6 files changed, 199 insertions(+), 330 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 6691d28c83d8..ea2df2230aa6 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -424,21 +424,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - if True: - dgm = modular_deep_gemm_fused_moe_fp8() - - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids): - return dgm(a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s) - else: - deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -450,8 +435,7 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 0dc572c72885..3cfed6ae8538 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional - import pytest import torch from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8, modular_cutlass_moe_fp8) +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) from vllm.platforms import current_platform @@ -16,48 +13,6 @@ TOP_KS = [6, 8] -def get_cutlass_moe_fp8(ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - out_dtype=torch.half) -> Callable: - if True: - return modular_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - out_dtype, - ) - else: - - def cutlass_moe_fp8_fn( - a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - a1_scale: Optional[torch.Tensor], - ) -> torch.Tensor: - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale, - out_dtype=out_dtype) - - return cutlass_moe_fp8_fn - - def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -66,22 +21,18 @@ def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - - cutlass_moe_fp8_fn = get_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - ) - - return cutlass_moe_fp8_fn(a, - w1_q, - w2_q, - w1_scale=w1_scale, - w2_scale=w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - a1_scale=a_scale) + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) @pytest.mark.parametrize("m", [2, 64, 224]) @@ -167,21 +118,18 @@ def test_cutlass_moe_no_graph( triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - cutlass_moe_fp8_fn = get_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - ) - - cutlass_output = cutlass_moe_fp8_fn(a, - w1_q, - w2_q, - w1_scale=w1_scale, - w2_scale=w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - a1_scale=a_scale1) + cutlass_output = cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale1) #print(triton_output) #print(cutlass_output) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 64e6d425bde2..669505c656d8 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -229,7 +229,6 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, ) -> torch.Tensor: - # TODO: chunking in here or in FusedMoEModularKernel? ignore for now M = a1q.shape[0] _, N, K = w2.shape # because w1 + w2 are transposed topk = topk_ids.shape[1] @@ -311,3 +310,76 @@ def modular_cutlass_moe_fp8( out_dtype, ), ) + + +#TODO make the grouped gemm kernel consistent with scaled gemm kernel +def cutlass_moe_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.half, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - ab_strides1 (torch.Tensor): The input and weights strides of the first + grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - ab_strides2 (torch.Tensor): The input and weights strides of the second + grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + - out_dtype (torch.dtype): The output tensor type. + + Returns: + - torch.Tensor: The fp16 output tensor after applying the MoE layer. + """ + fn = modular_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + out_dtype, + ) + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 6ffb40cb52cb..b19d1f52fa4a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -4,13 +4,12 @@ import torch -import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.dispatch_combine import ( StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_permute, _moe_unpermute_and_reduce) + _moe_permute) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) from vllm.utils import round_up @@ -58,186 +57,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, and w2.is_contiguous()) -def deep_gemm_moe_fp8( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with DeepGemm - grouped gemm. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1 (torch.Tensor): The first set of fp8 quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2 (torch.Tensor): The second set of fp8 quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - Returns: - - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. - """ - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - assert expert_map is None, "Expert maps not supported yet" - - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert w1.dtype == torch.float8_e4m3fn - assert w2.dtype == torch.float8_e4m3fn - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ - 0] == hidden_states.shape[0], "Input scale shape mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] - if global_num_experts == -1: - global_num_experts = E - - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - - assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) - - if inplace: - out_hidden_states = hidden_states - else: - out_hidden_states = torch.empty_like(hidden_states) - - block_m = dg.get_m_alignment_for_contiguous_layout() - block_shape = [block_m, block_m] - - assert w1_scale is not None - assert w2_scale is not None - - # We attempt to transpose and align offline in Fp8MoEMethod, in which - # case these calls will be nops. Otherwise, they'll be performed every - # time the layer is executed. - w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() - w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - - M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.empty(M_sum * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - - workspace1 = workspace13[:M_sum * N].view(M_sum, N) - workspace2 = torch.empty((M_sum, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - workspace3 = workspace13[:M_sum * K].view(M_sum, K) - - for chunk in range(num_chunks): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape - - if tokens_in_chunk == 0: - break - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - a1q_scale: Optional[torch.Tensor] = None - - qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, - a1_scale, block_shape) - - (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, - curr_topk_ids, global_num_experts, - expert_map, block_m) - - # Adjust the intermediate cache size and config for the last chunk. - # Note that in most cases we only have one chunk so the cache size - # and config are already set correctly and do not need to be adjusted. - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - curr_M = sorted_token_ids.numel() - workspace1 = _resize_cache(workspace1, (curr_M, N)) - workspace2 = _resize_cache(workspace2, (curr_M, N // 2)) - workspace3 = _resize_cache(workspace3, (curr_M, K)) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1, - expert_ids) - - if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") - - a2q_scale: Optional[torch.Tensor] = None - - qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale, - block_shape) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) - - _moe_unpermute_and_reduce( - out_hidden_states[begin_chunk_idx:end_chunk_idx], - workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) - - return out_hidden_states - - class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): @@ -274,7 +93,6 @@ def apply( ) -> torch.Tensor: import deep_gemm as dg - # TODO: chunking in here or in FusedMoEModularKernel? ignore for now _, N, K = w1.shape assert global_num_experts != -1 @@ -323,3 +141,73 @@ def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: block_shape=deep_gemm_block_shape()), DeepGemmExperts(), ) + + +def deep_gemm_moe_fp8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with DeepGemm + grouped gemm. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1 (torch.Tensor): The first set of fp8 quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2 (torch.Tensor): The second set of fp8 quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + + Returns: + - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. + """ + fn = modular_deep_gemm_fused_moe_fp8() + return fn( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index daae08f78b6e..ab23b9ff2437 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1157,30 +1157,6 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale=a1_scale, a2_scale=a2_scale, ) - elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: - fe = modular_triton_fused_moe( - use_fp8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - block_shape, - ) - return fe( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1_scale, - a2_scale, - ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1634,19 +1610,17 @@ def apply( global_num_experts = E top_k_num = topk_ids.shape[1] - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, dtype=hidden_states.dtype) - get_config_func = functools.partial( - try_get_optimal_moe_config, + config = try_get_optimal_moe_config( w1.shape, w2.shape, top_k_num, config_dtype, + num_tokens, block_shape=self.block_shape, ) @@ -1662,29 +1636,20 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") - curr_hidden_states = hidden_states - tokens_in_chunk, _ = curr_hidden_states.shape - # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, - (tokens_in_chunk, top_k_num, N)) - intermediate_cache2 = _resize_cache( - workspace2, (tokens_in_chunk * top_k_num, N // 2)) + (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache(workspace2, + (num_tokens * top_k_num, N // 2)) intermediate_cache3 = _resize_cache(workspace13, - (tokens_in_chunk, top_k_num, K)) - - config = get_config_func(tokens_in_chunk) - - curr_topk_ids = topk_ids - - qcurr_hidden_states, a1q_scale = hidden_states, a1q_scale + (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - invoke_fused_moe_kernel(qcurr_hidden_states, + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, a1q_scale, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5db49a630a4a..2dcbf0dd3415 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -9,22 +9,34 @@ # The goal is to be able to utilize different communication mechanisms with # any fused MoE kernel without needing to have combinatoric implementations. # -# Break the fused moe layer down into the following components. Each component -# will be independent of the others except for [Quantize-Dispatch] and -# [Combine]. The components can then be mixed and matched with different fused -# moe kernels so that DP+EP can be supported easily for multiple MoE -# implementations. +# The fused moe kernels are broken down into the following components: # -# Architecture: # [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] # +# Each component will be independent of the others except for +# [Quantize-Dispatch] and `[Combine] (see below). The components can then be +# mixed and matched with so that DP+EP can be supported easily for multiple +# MoE kernel implementations. +# +# The following main classes are defined: +# * FusedMoEQuantizeDispatchCombine - an abstract base class for quantization, +# dispatching and combing. The dispatch method takes care of any needed +# quantization and the combine method applies weights and does the final +# reduction of the output. +# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused +# MoE operation. One important feature to note is that this class does not +# apply topk weights or reduce the final output. +# * FusedMoEModularKernel - an interface class that combines a +# FusedMoEQuantizeDispatchCombine and a FusedMoEPermuteExpertsUnpermute to +# provide the standard fused MoE kernel interface. +# # [Quantize-Dispatch] and [Combine] functionality are bundled into a single # class `FusedMoEQuantizeDispatchCombine` since they could use collective # communication mechanisms that need to be consistent. # -def moe_problem_size( +def _moe_problem_size( a1: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -260,7 +272,7 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - E, M, N, K, top_k = moe_problem_size(a1, w1, w2, topk_ids) + E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) if global_num_experts == -1: global_num_experts = E From cef98ab82de4965e8671135110fc4c7bc559026a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 01:51:15 +0000 Subject: [PATCH 021/171] review comments Signed-off-by: Bill Nell --- .../layers/fused_moe/deep_gemm_moe.py | 16 ++++++++++++---- .../layers/fused_moe/pplx_dispatch_combine.py | 4 +++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b19d1f52fa4a..250f03ae7f08 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -31,7 +31,6 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int): return align <= M and N % align == 0 and K % align == 0 -# TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -42,19 +41,28 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, aligned by `dg.get_m_alignment_for_contiguous_layout()`. """ if not has_deep_gemm: + logger.debug("DeepGemm disabled: deep_gemm not available.") return False - # Expert maps not supported yet. if expert_map is not None: + logger.debug("DeepGemm disabled: expert map NYI.") return False M = hidden_states.shape[0] _, K, N = w2.shape if not _valid_deep_gemm_shape(M, N, K): + logger.debug("DeepGemm disabled: unalinged problem size.") return False - return (hidden_states.is_contiguous() and w1.is_contiguous() - and w2.is_contiguous()) + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug("DeepGemm disabled: invalid weight dtype(s).") + return False + + if (not hidden_states.is_contiguous() or not w1.is_contiguous() + or not w2.is_contiguous()): + logger.debug( + "DeepGemm disabled: weights or activations not contiguous.") + return False class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 5c844ff57a76..936aee14a7bc 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -24,6 +24,7 @@ def __init__(self, super().__init__() self.a2a = a2a self.block_shape = block_shape + self.max_num_tokens = max_num_tokens self.dp_num_tokens = max_num_tokens * (world_size // dp_size) self.quant_dtype = quant_dtype @@ -109,7 +110,8 @@ def combine( dtype=torch.uint32, device=output.device) - # TODO assert output is the proper size + assert output.shape[0] == self.max_num_tokens + assert output.shape[1] == fused_expert_output.shape[-1] self.a2a.combine(out_tokens=output, indices=topk_ids, From 13da7eab8e54b64322f814e84efe478258aaeeb5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 01:58:16 +0000 Subject: [PATCH 022/171] forgot return Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 250f03ae7f08..e9adb335355d 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -64,6 +64,8 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, "DeepGemm disabled: weights or activations not contiguous.") return False + return True + class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): From fb39d5072f932dbbf5ac472f0fc05f06a372bf47 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 15:08:02 +0000 Subject: [PATCH 023/171] add dp_rank_num_tokens to DPMetadata Signed-off-by: Bill Nell --- vllm/forward_context.py | 7 ++++++- .../layers/fused_moe/pplx_dispatch_combine.py | 9 +++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c75d8f088c5b..ea3aaa66b4cf 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -32,6 +32,7 @@ @dataclass class DPMetadata: cu_tokens_across_dp_cpu: torch.Tensor + dp_rank_num_tokens: torch.Tensor @dataclass @@ -90,7 +91,11 @@ def set_forward_context(attn_metadata: Any, from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) - dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) + dp_rank_num_tokens = torch.tensor( + [num_tokens], + dtype=torch.uint32, + device=vllm_config.device_config.device) + dp_metadata = DPMetadata(cu_tokens_across_dp_cpu, dp_rank_num_tokens) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 936aee14a7bc..d35cfaccd39d 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -5,6 +5,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize @@ -83,9 +84,7 @@ def dispatch( ) # This argument is optional - bound_m = torch.tensor([a1q.shape[0]], - dtype=torch.uint32, - device=device) + bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -106,9 +105,7 @@ def combine( topk_ids: torch.Tensor, ) -> None: # This argument is optional - bound_m = torch.tensor([output.shape[0]], - dtype=torch.uint32, - device=output.device) + bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens assert output.shape[0] == self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] From 9bac87a5d6f78097d557f6170bb2735a08eb2150 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 22:29:28 +0000 Subject: [PATCH 024/171] better check for fp8 in _fp8_permute Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 152007d42169..93f9158f15ef 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -44,7 +44,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """ A permutation routine that works on fp8 types. """ - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + if torch.is_floating_point(m) and m.dtype.itemsize == 1: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] From 9882f970cf724cb90f82b9526b9a7b5145b17a34 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 28 Apr 2025 18:38:48 +0000 Subject: [PATCH 025/171] updates Signed-off-by: Bill Nell --- .../layers/fused_moe/deep_gemm_moe.py | 22 ++++---- .../layers/fused_moe/dispatch_combine.py | 4 +- .../layers/fused_moe/fused_moe.py | 48 ++++++++++-------- .../layers/fused_moe/modular_kernel.py | 50 ++++++++++++++----- vllm/model_executor/layers/fused_moe/utils.py | 2 +- 5 files changed, 79 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index e9adb335355d..e43c984f7d5f 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -73,15 +73,21 @@ def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N - return (workspace1, workspace2, a_dtype) + return (workspace1, workspace2, a.dtype) def apply( self, @@ -100,6 +106,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: import deep_gemm as dg @@ -126,12 +133,7 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") + self.activation(activation, workspace2, workspace1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 06b90c350252..398aab60c660 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -26,7 +26,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: if self.quant_dtype == torch.float8_e4m3fn: per_act_token = a1_scale.numel( ) != 1 if a1_scale is not None else ( @@ -42,7 +42,7 @@ def dispatch( a1q = a1 a1q_scale = a1_scale - return a1q, a1q_scale + return a1q, a1q_scale, None def combine( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ab23b9ff2437..5cd42fa8b2e0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1262,7 +1262,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2], \ + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -1272,7 +1273,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, torch.float32, torch.float16, torch.bfloat16 ] - num_tokens, _ = hidden_states.shape + num_tokens = hidden_states.shape[0] E, N, _ = w1.shape K = w2.shape[1] if global_num_experts == -1: @@ -1554,20 +1555,28 @@ def __init__( use_fp8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, - block_shape: Optional[List[int]], + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, ): super().__init__() self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape + self.block_m = block_m - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(N * 2, K) workspace2 = M * topk * N - return (workspace1, workspace2, a_dtype) + return (workspace1, workspace2, a.dtype) def apply( self, @@ -1586,14 +1595,16 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: # Check constraints. if self.use_int4_w4a16: - assert hidden_states.shape[1] // 2 == w1.shape[ + assert hidden_states.shape[-1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[ - 2], "Hidden size mismatch" + assert hidden_states.shape[-1] == w1.shape[2], \ + (f"Hidden size mismatch {hidden_states.shape[-1]} " + f"!= {w1.shape[2]}") assert hidden_states.is_contiguous( ), "Hidden_states must be contiguous" @@ -1603,12 +1614,11 @@ def apply( torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + if global_num_experts == -1: global_num_experts = E - top_k_num = topk_ids.shape[1] config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, @@ -1668,14 +1678,8 @@ def apply( use_int4_w4a16=self.use_int4_w4a16, block_shape=self.block_shape) - if activation == "silu": - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") + self.activation(activation, intermediate_cache2, + intermediate_cache1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2dcbf0dd3415..b517f6ee13c5 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -56,13 +56,19 @@ def _moe_problem_size( of w1 or w2. Similarly, some kernels transpose the weights, so this needs to be kept in mind. """ - - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.shape[0] == a1.shape[0] - M, _ = a1.shape + assert w1.dim() == 3 and w2.dim() == 3 E, N, _ = w1.shape K = w2.shape[1] + + assert a1.dim() == 2 + assert topk_ids.dim() == 2 + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.shape[0] == a1.shape[ + 0], f"{topk_ids.shape[0]} != {a1.shape[0]}" + + M = a1.shape[0] topk = topk_ids.shape[1] + return E, M, N, K, topk @@ -81,7 +87,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -127,9 +133,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ @abstractmethod - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: """ Compute the number of elements for the temporary outputs of the two gemms and activation in the fused expert function. Since the @@ -145,6 +157,15 @@ def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, """ raise NotImplementedError + def activation(self, activation: str, output: torch.Tensor, + input: torch.Tensor) -> None: + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(output, input) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + @abstractmethod def apply( self, @@ -163,6 +184,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: """ This function computes the intermediate result of a Mixture of Experts @@ -193,6 +215,8 @@ def apply( must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation function. + - expert_num_tokens: An optional tensor containing the number of tokens + assigned to each expert when using batched experts format input. Returns: - torch.Tensor: The unweighted, unreduced output tensor @@ -224,7 +248,7 @@ def __init__( def forward( self, - a1: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -245,7 +269,7 @@ def forward( of weights, w1 and w2, and top-k gating mechanism. Parameters: - - a1: (torch.Tensor): The input tensor to the MoE layer. + - hidden_states: (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights (torch.Tensor): The topk weights applied at the end of @@ -272,6 +296,7 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ + a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) if global_num_experts == -1: @@ -280,7 +305,7 @@ def forward( output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(a1.dtype, M, N, K, top_k, + self.fused_experts.workspace_shapes(a1, M, N, K, top_k, global_num_experts)) # We can reuse the memory between cache1 and cache3 because by the time @@ -292,7 +317,7 @@ def forward( device=a1.device, dtype=workspace_dtype) - a1q, a1q_scale = self.dispatch_combine.dispatch( + a1q, a1q_scale, expert_num_tokens = self.dispatch_combine.dispatch( a1, a1_scale, a2_scale, @@ -317,6 +342,7 @@ def forward( a2_scale, workspace13=workspace13, workspace2=workspace2, + expert_num_tokens=expert_num_tokens, ) self.dispatch_combine.combine(output, fused_out, topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 93f9158f15ef..eff39c0f7922 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -15,7 +15,7 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel() + assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" return x.flatten()[:prod(v)].view(*v) From cfcdb703f2d6954b4629393e734a518bdaeb7f89 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 02:06:11 +0000 Subject: [PATCH 026/171] fix merge issues Signed-off-by: Bill Nell --- tests/kernels/moe/test_cutlass_moe.py | 22 +- tests/kernels/moe/test_moe.py | 44 +-- tests/kernels/quantization/test_block_fp8.py | 29 +- tests/kernels/quantization/test_block_int8.py | 5 +- tests/kernels/test_cutlass_moe.py | 244 --------------- .../layers/fused_moe/cutlass_moe.py | 286 ++++++------------ .../layers/fused_moe/deep_gemm_moe.py | 8 +- .../layers/fused_moe/dispatch_combine.py | 43 +-- .../layers/fused_moe/fused_moe.py | 148 ++++----- .../layers/fused_moe/modular_kernel.py | 44 +-- .../layers/fused_moe/moe_permute_unpermute.py | 4 +- .../layers/fused_moe/pplx_dispatch_combine.py | 41 ++- vllm/model_executor/layers/fused_moe/utils.py | 48 ++- 13 files changed, 346 insertions(+), 620 deletions(-) delete mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 975cd418a171..7d24307e353a 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -30,6 +30,11 @@ (224, 3072, 1536), ] +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @dataclasses.dataclass class MOETensors: @@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, - 'topk_ids_': topk_ids, + 'topk_ids': topk_ids, 'ab_strides1': moe_tensors.ab_strides1, 'c_strides1': moe_tensors.c_strides1, 'ab_strides2': moe_tensors.ab_strides2, @@ -231,10 +236,7 @@ def test_cutlass_moe_8_bit_no_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) @@ -276,10 +278,7 @@ def test_cutlass_moe_8_bit_cuda_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): dtype = torch.half mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, @@ -334,10 +333,7 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index cbc20a57cf19..c58ddbb74e38 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,9 +11,8 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) -from vllm import _custom_ops as ops +from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, + torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -31,6 +30,10 @@ EP_SIZE = [1, 4] TOP_KS = [2, 6] +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @@ -69,7 +72,6 @@ def test_fused_moe( else: e_map = None - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): torch_output = torch_moe(a, w1, w2, score, topk, e_map) iterative_output = iterative_moe(a, @@ -196,22 +198,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + with set_current_vllm_config(vllm_config): + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index ea2df2230aa6..d0eca89c04e0 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -211,6 +211,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # Set the context to avoid lots of warning spam. vllm_config = VllmConfig() + vllm_config.scheduler_config.max_num_seqs = 128 + vllm_config.scheduler_config.max_model_len = 8192 + with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -386,8 +389,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): if not _valid_deep_gemm_shape(M, N, K): pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") - vllm_config = VllmConfig() - torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -424,7 +425,26 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + if True: + dgm = modular_deep_gemm_fused_moe_fp8() + + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids): + return dgm(a, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + # Set the context to avoid lots of warning spam. + vllm_config = VllmConfig() + vllm_config.scheduler_config.max_num_seqs = 128 + vllm_config.scheduler_config.max_model_len = 8192 + with set_current_vllm_config(vllm_config): if M >= 128: ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, @@ -435,7 +455,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index 104f23fd7cd2..a4e9f83f0eaf 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -18,6 +18,10 @@ pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # For test def native_per_token_group_quant_int8(x, @@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py deleted file mode 100644 index 3cfed6ae8538..000000000000 --- a/tests/kernels/test_cutlass_moe.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import pytest -import torch - -from vllm import _custom_ops as ops -from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, - fused_topk) -from vllm.platforms import current_platform - -NUM_EXPERTS = [40, 64] -TOP_KS = [6, 8] - - -def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_no_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. - _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - cutlass_output = cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale1) - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=5e-2, - rtol=1e-2) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_cuda_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. - _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - stream = torch.cuda.Stream() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, - c_strides1, ab_strides2, c_strides2) - torch.cuda.synchronize() - graph.replay() - torch.cuda.synchronize() - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=9e-2, - rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 669505c656d8..d718ac1f3f3a 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -11,180 +11,6 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache -#TODO make the grouped gemm kernel consistent with scaled gemm kernel -def cutlass_moe_fp8( - a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids_: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.half, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with CUTLASS - grouped gemm. - - Parameters: - - a (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - ab_strides1 (torch.Tensor): The input and weights strides of the first - grouped gemm. - - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. - - ab_strides2 (torch.Tensor): The input and weights strides of the second - grouped gemm. - - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. - - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - out_dtype (torch.dtype): The output tensor type. - - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, - every Rank is responsible for a subset of experts. expert_map is a - mapping from global expert-id to local expert-id. When expert_map[i] - is -1, it means that this Rank is not responsible for global - expert-id i. - - apply_router_weight_on_input (bool): When true, the topk weights are - applied directly on the inputs. This is only applicable when topk is 1. - - Returns: - - torch.Tensor: The fp16 output tensor after applying the MoE layer. - """ - - assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" - assert w1_q.dtype == torch.float8_e4m3fn - assert w2_q.dtype == torch.float8_e4m3fn - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1_q.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2_q.shape[2], "W2 scale shape mismatch" - assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert ab_strides1.shape[0] == w1_q.shape[ - 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1_q.shape[ - 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2_q.shape[ - 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2_q.shape[ - 0], "C Strides 2 expert number mismatch" - assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - - num_experts = w1_q.size(0) - m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - - local_topk_ids = topk_ids_ - if expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where(expert_map[topk_ids_] != -1, - expert_map[topk_ids_], -1) - - topk = local_topk_ids.size(1) - - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - if apply_router_weight_on_input: - assert topk == 1, \ - "apply_router_weight_on_input is only implemented for topk=1" - # TODO: this only works for topK=1, will need to update for topK>1 - a = a * topk_weights.to(out_dtype) - - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) - device = a_q.device - - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - - a_map_initializer = torch.empty - c2_initializer = torch.empty - if expert_map is not None: - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. - a_map_initializer = torch.zeros - c2_initializer = torch.zeros - - a_map = a_map_initializer((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_experts, n, - k) - - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - - c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype) - - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, - expert_offsets[:-1], problem_sizes1, ab_strides1, - ab_strides1, c_strides1) - - intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) - torch.ops._C.silu_and_mul(intermediate, c1) - - intemediate_q, a2_scale = ops.scaled_fp8_quant( - intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, ab_strides2, - ab_strides2, c_strides2) - - # Gather tokens - c2 = c2[c_map].view(m, topk, k) - if not apply_router_weight_on_input: - c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) - return c2.sum(dim=1) - - class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -202,9 +28,15 @@ def __init__( self.c_strides2 = c_strides2 self.out_dtype = out_dtype - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: # Note that K, N are transposed N, K = K, N workspace1 = M * topk * max(2 * N, K) @@ -213,7 +45,7 @@ def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, def apply( self, - a1q: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, @@ -228,16 +60,56 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: + a1q = hidden_states + + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" + assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim( + ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ + 0], "Input scale shape mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2.shape[2], "W2 scale shape mismatch" + assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" + assert w1.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert self.ab_strides1.shape[0] == w1.shape[ + 0], "AB Strides 1 expert number mismatch" + assert self.c_strides1.shape[0] == w1.shape[ + 0], "C Strides 1 expert number mismatch" + assert self.ab_strides2.shape[0] == w2.shape[ + 0], "AB Strides 2 expert number mismatch" + assert self.c_strides2.shape[0] == w2.shape[ + 0], "C Strides 2 expert number mismatch" + assert self.out_dtype in [torch.half, + torch.bfloat16], "Invalid output dtype" + M = a1q.shape[0] _, N, K = w2.shape # because w1 + w2 are transposed - topk = topk_ids.shape[1] device = a1q.device assert w1.shape[1] == K assert global_num_experts != -1 assert a1q_scale is not None + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids] != -1, + expert_map[topk_ids], -1) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.shape[1] + per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) @@ -251,21 +123,29 @@ def apply( dtype=torch.int32, device=device) - a_map = torch.empty((topk_ids.numel()), - dtype=torch.int32, - device=device) - c_map = torch.empty((topk_ids.numel()), + # With expert_map each Rank processes only a subset of experts. As + # a result not all of a_map and c2 tensors are filled. We fill it + # zeros for correctness. + if expert_map is not None: + a_map = torch.zeros((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + else: + a_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, - global_num_experts, N, K) + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, a_map, + c_map, global_num_experts, N, K) a1q = _fp8_perm(a1q, a_map) a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale - # fix names c1 = _resize_cache(workspace13, (M * topk, N * 2)) c2 = _resize_cache(workspace2, (M * topk, N)) c3 = _resize_cache(workspace13, (M * topk, K)) @@ -274,16 +154,14 @@ def apply( expert_offsets[:-1], problem_sizes1, self.ab_strides1, self.ab_strides1, self.c_strides1) - if activation == "silu": - torch.ops._C.silu_and_mul(c2, c1) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(c2, c1) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") + self.activation(activation, c2, c1) a2q, a2q_scale = ops.scaled_fp8_quant( c2, a2_scale, use_per_token_if_dynamic=per_act_token) + if expert_map is not None: + c3.fill_(0) + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets[:-1], problem_sizes2, self.ab_strides2, self.ab_strides2, self.c_strides2) @@ -294,6 +172,7 @@ def apply( def modular_cutlass_moe_fp8( + per_act_token: bool, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, @@ -301,7 +180,10 @@ def modular_cutlass_moe_fp8( out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn), + StandardDispatchCombine( + per_channel_quant=per_act_token, + quant_dtype=torch.float8_e4m3fn, + ), CutlassExperts( ab_strides1, c_strides1, @@ -328,6 +210,8 @@ def cutlass_moe_fp8( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.half, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -361,25 +245,39 @@ def cutlass_moe_fp8( quantize the intermediate result between the gemms. Shape: scalar or [M] - out_dtype (torch.dtype): The output tensor type. + - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for a subset of experts. expert_map is a + mapping from global expert-id to local expert-id. When expert_map[i] + is -1, it means that this Rank is not responsible for global + expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + fn = modular_cutlass_moe_fp8( + per_act_token, ab_strides1, c_strides1, ab_strides2, c_strides2, out_dtype, ) + return fn( a, w1_q, w2_q, topk_weights, topk_ids, + expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index e43c984f7d5f..266ba3bfa07a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -91,7 +91,7 @@ def workspace_shapes( def apply( self, - a1q: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, @@ -110,6 +110,7 @@ def apply( ) -> torch.Tensor: import deep_gemm as dg + a1q = hidden_states _, N, K = w1.shape assert global_num_experts != -1 @@ -137,7 +138,8 @@ def apply( a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, self.block_shape) + a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False, + self.block_shape) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) @@ -169,6 +171,7 @@ def deep_gemm_moe_fp8( expert_map: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input=False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -222,4 +225,5 @@ def deep_gemm_moe_fp8( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 398aab60c660..9b647a70d5e0 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -6,15 +6,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_unpermute_and_reduce) -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + ): super().__init__() + self.per_channel_quant = per_channel_quant self.block_shape = block_shape self.quant_dtype = quant_dtype @@ -23,24 +28,23 @@ def dispatch( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel( - ) != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) - else: - a1q = a1 - a1q_scale = a1_scale + if apply_router_weight_on_input: + topk = topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + self.per_channel_quant, + self.block_shape) return a1q, a1q_scale, None @@ -50,6 +54,7 @@ def combine( fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: _moe_unpermute_and_reduce(output, fused_expert_output, None, - topk_weights) + topk_weights, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5cd42fa8b2e0..1f03817b0544 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -19,12 +19,8 @@ StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) -from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_group_quant_int8, per_token_quant_int8) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -967,6 +963,20 @@ def get_config_dtype_str( return None +# TODO: use scalar_type? +def get_config_qtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + return None + + def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1156,6 +1166,7 @@ def fused_experts(hidden_states: torch.Tensor, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) else: return dispatch_fused_experts_func(inplace)( @@ -1182,59 +1193,6 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) -def moe_kernel_prepare_input( - A: torch.Tensor, - B: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[List[int]] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if use_fp8_w8a8: - assert B_scale is not None - if block_shape is None: - # If weights are per-channel (per_channel_quant=True), then - # activations apply per-token quantization. Otherwise, assume - # activation tensor-wise fp8 quantization, dynamic or static - A, A_scale = ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_channel_quant) - else: - # activation block-wise fp8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a8: - assert B_scale is not None - if block_shape is None: - # activation channel-wise int8 quantization - assert (per_channel_quant - ), "int8 quantization only supports block or channel-wise" - A, A_scale = per_token_quant_int8(A) - else: - # activation block-wise int8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_int8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16 or use_int4_w4a16: - assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - - return A, A_scale - - def fused_experts_impl(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1288,6 +1246,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype) + qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, @@ -1350,15 +1313,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input( + qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, - B=w1, A_scale=a1_scale, - B_scale=w1_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) @@ -1369,7 +1327,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, - qa1_scale, + a1q_scale, w1_scale, w1_zp, curr_topk_weights, @@ -1396,22 +1354,17 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - qintermediate_cache2, qa2_scale = moe_kernel_prepare_input( + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, - B=w2, A_scale=a2_scale, - B_scale=w2_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, - qa2_scale, + a2q_scale, w2_scale, w2_zp, curr_topk_weights, @@ -1553,17 +1506,25 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None, block_m: Optional[int] = None, ): super().__init__() self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape self.block_m = block_m + self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + self.per_channel_quant = per_channel_quant def workspace_shapes( self, @@ -1674,8 +1635,10 @@ def apply( config, compute_type=compute_type, use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) self.activation(activation, intermediate_cache2, @@ -1683,12 +1646,9 @@ def apply( a2q_scale: Optional[torch.Tensor] = None - if self.use_fp8_w8a8: - qintermediate_cache2, a2q_scale = _fp8_quantize( - intermediate_cache2, a2_scale, self.block_shape) - else: - qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, + self.block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, @@ -1705,8 +1665,10 @@ def apply( config, compute_type=compute_type, use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) return intermediate_cache3 @@ -1714,18 +1676,30 @@ def apply( def modular_triton_fused_moe( use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None, ) -> mk.FusedMoEModularKernel: + qtype = get_config_qtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ) return mk.FusedMoEModularKernel( StandardDispatchCombine( - quant_dtype=torch.float8_e4m3fn if use_fp8_w8a8 else None, - block_shape=block_shape), + quant_dtype=qtype, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ), TritonExperts( - use_fp8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - block_shape, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, ), ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b517f6ee13c5..aab7658ae641 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -84,9 +84,11 @@ def dispatch( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed @@ -95,7 +97,8 @@ def dispatch( - a1_scale: Optional scales for a1 - a2_scale: Optional scales for the second MoE gemm. Required to make sure the quantization is consistent for both gemms. - - topk_ids: The topk_ids. + - topk_ids: The topk ids. + - topk_weights: The topk weights. - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. @@ -113,6 +116,7 @@ def combine( fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: """ Perform any combine plus apply weights and perform a reduction on the @@ -169,7 +173,7 @@ def activation(self, activation: str, output: torch.Tensor, @abstractmethod def apply( self, - a1q: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, @@ -191,7 +195,8 @@ def apply( (MoE) layer using two sets of weights, w1 and w2. Parameters: - - a1q: (torch.Tensor): The (quantized) input tensor to the MoE layer. + - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE + layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_ids (torch.Tensor): A map of row to expert id. @@ -263,6 +268,7 @@ def forward( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -292,6 +298,9 @@ def forward( w2. - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is + 1. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -318,34 +327,29 @@ def forward( dtype=workspace_dtype) a1q, a1q_scale, expert_num_tokens = self.dispatch_combine.dispatch( - a1, - a1_scale, - a2_scale, - topk_ids, - global_num_experts, - expert_map, - ) + a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, + expert_map, apply_router_weight_on_input) fused_out = self.fused_experts.apply( a1q, w1, w2, topk_ids, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1q_scale, - a2_scale, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, expert_num_tokens=expert_num_tokens, ) self.dispatch_combine.combine(output, fused_out, topk_weights, - topk_ids) + topk_ids, apply_router_weight_on_input) return output diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 93a3d8ab9c18..0d89f3f22c31 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -56,6 +56,7 @@ def _moe_unpermute_and_reduce( curr_hidden: torch.Tensor, inv_perm: Optional[torch.Tensor], topk_weight: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: """ Unpermute the final result and apply topk_weights, then perform the final @@ -66,5 +67,6 @@ def _moe_unpermute_and_reduce( if inv_perm is not None: curr_hidden = curr_hidden[inv_perm, ...] curr_hidden = curr_hidden.view(-1, topk, K) - curr_hidden.mul_(topk_weight.view(M, -1, 1)) + if not apply_router_weight_on_input: + curr_hidden.mul_(topk_weight.view(M, -1, 1)) ops.moe_sum(curr_hidden, out) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index d35cfaccd39d..90a4833948f8 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -6,7 +6,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) # Note use: layer.get_all_to_all() to get an AllToAll instance @@ -34,27 +35,33 @@ def dispatch( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + rank_topk_weights: torch.Tensor, rank_topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device - if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel( - ) != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + assert expert_map is None, "NYI" - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) - else: - a1q = a1 - a1q_scale = a1_scale + # TBD + assert not apply_router_weight_on_input + if apply_router_weight_on_input: + topk = rank_topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1 = a1 * rank_topk_weights.to(a1.dtype) + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + per_act_token, + self.block_shape) expert_num_tokens = torch.empty( num_experts, @@ -103,6 +110,7 @@ def combine( fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: # This argument is optional bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -110,6 +118,11 @@ def combine( assert output.shape[0] == self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] + # Set weights to 1? + assert not apply_router_weight_on_input + if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) + self.a2a.combine(out_tokens=output, indices=topk_ids, weights=topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index eff39c0f7922..d53da1d7926e 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -7,6 +7,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, per_token_quant_int8) from vllm.utils import cdiv @@ -22,8 +24,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], + per_act_token: bool, block_shape: Optional[List[int]] = None, - per_act_token: bool = False, # make sure this is the same default as op ) -> Tuple[torch.Tensor, torch.Tensor]: """ Perform fp8 quantization on the inputs. If a block_shape @@ -37,9 +39,53 @@ def _fp8_quantize( _, block_k = block_shape[0], block_shape[1] A, A_scale = per_token_group_quant_fp8(A, block_k) assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + + return A, A_scale + + +def _int8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token: bool, + block_shape: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform int8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + + # If weights are per-channel (per_channel_quant=True), then + # activations apply per-token quantization. Otherwise, assume + # activation tensor-wise fp8/int8 quantization, dynamic or static + if block_shape is None: + assert per_act_token, \ + "int8 quantization only supports block or channel-wise" + A, A_scale = per_token_quant_int8(A) + else: + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_int8(A, block_k) + assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + return A, A_scale +def moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + qtype: Optional[torch.dtype], + per_channel_quant: bool, + block_shape: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if qtype == torch.float8_e4m3fn: + return _fp8_quantize(A, A_scale, per_channel_quant, block_shape) + elif qtype == torch.int8: + return _int8_quantize(A, A_scale, per_channel_quant, block_shape) + else: + assert A_scale is None + return A, A_scale + + def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """ A permutation routine that works on fp8 types. From 4664e0f7cc133d7104d9b468e1c29324f2da8669 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 02:28:00 +0000 Subject: [PATCH 027/171] fix lint Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 2 + .../layers/fused_moe/pplx_dispatch_combine.py | 43 ++++++++++++------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d718ac1f3f3a..e52751eddf2c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -64,6 +64,8 @@ def apply( ) -> torch.Tensor: a1q = hidden_states + assert w1_scale is not None + assert w2_scale is not None assert w1.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 90a4833948f8..658705515b43 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -5,15 +5,13 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same -# as the ones used to create the AllToAll. Unfortunately, there's -# no way(?) to extract this info from AllToAll +# as the ones used to create the AllToAll. class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): def __init__(self, @@ -21,13 +19,16 @@ def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, + rank: int, quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a self.block_shape = block_shape self.max_num_tokens = max_num_tokens - self.dp_num_tokens = max_num_tokens * (world_size // dp_size) + self.world_size = world_size + self.dp_size = dp_size + self.rank = rank self.quant_dtype = quant_dtype def dispatch( @@ -39,8 +40,8 @@ def dispatch( rank_topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + apply_router_weight_on_input: bool, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device @@ -63,14 +64,19 @@ def dispatch( per_act_token, self.block_shape) + rem_experts = num_experts % self.world_size + num_local_experts = ((num_experts // self.world_size) + + (1 if self.rank < rem_experts else 0)) + expert_num_tokens = torch.empty( - num_experts, + num_local_experts, dtype=torch.int32, device=device, ) + num_dp = self.world_size // self.dp_size expert_x = torch.empty( - (num_experts, self.dp_num_tokens, a1q.shape[-1]), + (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, device=device, ) @@ -90,8 +96,14 @@ def dispatch( device=device, ) - # This argument is optional - bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + # This argument is optional, defaults to indices.shape[0] + # This causes a deadlock???? + #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) + bound_m = None + + # TODO: optimize this? + indices = rank_topk_ids.to(dtype=torch.uint32) self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -99,10 +111,10 @@ def dispatch( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=rank_topk_ids, + indices=indices, bound_m=bound_m, ) - return expert_x, expert_x_scale + return expert_x, expert_x_scale, expert_num_tokens def combine( self, @@ -113,9 +125,10 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + bound_m = None - assert output.shape[0] == self.max_num_tokens + assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] # Set weights to 1? @@ -124,7 +137,7 @@ def combine( topk_weights = torch.ones_like(topk_weights) self.a2a.combine(out_tokens=output, - indices=topk_ids, + indices=topk_ids.to(torch.uint32), weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) From 42f12d755ddb11451c17f430aa6b13e7c88e5bac Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 03:21:26 +0000 Subject: [PATCH 028/171] add pplx tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 598 ++++++++++++++++++ .../layers/fused_moe/fused_batched_moe.py | 175 +++++ .../layers/fused_moe/fused_moe.py | 14 + 3 files changed, 787 insertions(+) create mode 100644 tests/kernels/moe/test_pplx_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/fused_batched_moe.py diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py new file mode 100644 index 000000000000..cab9990b16b5 --- /dev/null +++ b/tests/kernels/moe/test_pplx_moe.py @@ -0,0 +1,598 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_pplx_moe.py`. +""" +import dataclasses +import os +import pytest +import torch +import traceback + +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] +from typing import Callable, Concatenate, Optional, ParamSpec, Tuple + +from pplx_kernels import AllToAll +from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, +) + +import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) +from vllm.platforms import current_platform + +from vllm.model_executor.layers.activation import SiluAndMul + +from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, fused_experts) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedDispatchCombine, BatchedExperts) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +P = ParamSpec("P") + +require_multi_node = pytest.mark.skipif( + "MASTER_ADDR" not in os.environ, + reason="Requires multi-node environment", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exception(ex) + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_dispatch( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + max_num_tokens: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens = a.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + if max_num_tokens is None: + max_num_tokens = tokens_per_expert.max() + + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + dtype=a.dtype, device=a.device) + + #print(f"b_a shape {b_a.shape}") + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx+1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_combine(b_out, topk_weight, topk_ids): + num_tokens, topk = topk_ids.shape + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return out + + +def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and K == w2.shape[1] + out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_combine(out, topk_weight, topk_ids) + + +# TODO: same as torch_moe but with fused_topk factored out. +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + if True: + triton_output = torch_batched_moe(a, + w1, + w2, + topk_weight, + topk_ids) + else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +def chunk_by_rank(t, r, w): + chunk = rank_chunk(t.shape[0], r, w) + #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") + return t[(r * chunk):(r + 1)*chunk] + + +def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) + max_num_tokens = num_tokens + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + + b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( + a_chunk, + None, + None, + chunk_topk_weight, + chunk_topk_ids, + num_experts, # store at PplxDispatchCombine creation? + None, + False, + ) + + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) + + torch.distributed.all_reduce(tokens_per_expert) + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) + + torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) + + b_a = b_a * 1.5 + + out = torch.full( + (rank_num_tokens * world_size, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + dispatch_combine.combine( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + False, + ) + torch.cuda.synchronize() + + ata.destroy() + + return out[:rank_num_tokens] + + +def _pplx_dispatch_combine( + pgi: ProcessGroupInfo, + dp_size: int, + m, n, k, e, + topk: int, + dtype: torch.dtype, +): + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device + + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + a_rep = torch.repeat_interleave(a, topk, dim=0) + + torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + + pplx_output = torch_pplx_dispatch_combine(pgi, + dp_size, + a, + w1, + w2, + score, + topk) + + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("m", [4, 32, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +def test_pplx_dispatch_combine( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: Tuple[int, int], +): + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + + parallel_launch( + world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype + ) + + +def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) + max_num_tokens = num_tokens + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + w1 = w1.to(device) + w2 = w2.to(device) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) + + fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + + out = fused_experts( + a_chunk, + # Chunking weights like this only works for batched format + chunk_by_rank(w1, rank, world_size), + chunk_by_rank(w2, rank, world_size), + #w1, + #w2, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts #? num_local_experts? + ) + + torch.cuda.synchronize() + + ata.destroy() + + return out[:rank_num_tokens] + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + dtype: torch.dtype, +): + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + m, k = a.shape + e, _, n = w2.shape + + torch.set_printoptions(profile="full") + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + pplx_output = torch_pplx_moe(pgi, + dp_size, + a, + w1, + w2, + score, + topk) + + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +# TODO: M == 1 doesn't work +@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +def test_pplx_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: Tuple[int, int], +): + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + parallel_launch( + world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype + ) + diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py new file mode 100644 index 000000000000..a39d08b83768 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused batched MoE kernel.""" +from typing import List, Optional, Tuple + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import _resize_cache + + +class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, + world_size: int, + rank: int): + super().__init__() + self.world_size = world_size + self.rank = rank + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a1.shape[0] + + if apply_router_weight_on_input: + topk = topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + num_tokens = a1.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + max_num_tokens = tokens_per_expert.max() + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) + + b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), + dtype=a1.dtype, device=a1.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = expert_counts[expert_id] + b_a1[expert_id, idx:idx+1, :] = a1[token, :] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return b_a1, a1_scale, tokens_per_expert + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + num_tokens = topk_ids.shape[0] + num_experts = fused_expert_output.shape[0] + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(topk_ids.shape[1]): + expert_id = expert_ids[i] + if expert_id < num_experts: + idx = expert_counts[expert_id] + if apply_router_weight_on_input: + output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] + else: + output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + rank: int = 0, + world_size: int = 1, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + assert not use_fp8_w8a8 + assert not use_int4_w4a16 + assert not use_int8_w8a16 + assert block_shape is None + assert block_m is None + self.max_num_tokens = max_num_tokens + self.rank = rank + self.world_size = world_size + assert not use_fp8_w8a8, "NYI" + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack + workspace2 = max_num_tokens * N + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + assert hidden_states.dim() == 3 + assert expert_num_tokens is not None + num_tokens = topk_ids.shape[0] + _, tmp_max_num_tokens, K = hidden_states.shape + max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens + num_experts = global_num_experts + out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + num_local_experts = expert_num_tokens.numel() + + # TODO: don't need world_size or rank if expert_base always == 0 + #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" + #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank + expert_base = 0 + + for expert in range(num_local_experts): + num = expert_num_tokens[expert] + assert num <= max_num_tokens, f"{num}, {max_num_tokens}" + if num > 0: + tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) + self.activation( + activation, + tmp, + hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1) + ) + out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) + + return out diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1f03817b0544..e153e165b621 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -486,6 +486,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if use_fp8_w8a8: + assert B_scale is not None + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + M = A.shape[0] num_tokens = M * top_k From dc0a640d63849b8115151f775764960fea274faa Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 03:26:54 +0000 Subject: [PATCH 029/171] lint Signed-off-by: Bill Nell --- .../cutlass_benchmarks/w8a8_benchmarks.py | 2 +- tests/kernels/moe/test_pplx_moe.py | 197 ++++++++---------- .../layers/fused_moe/fused_batched_moe.py | 60 ++++-- 3 files changed, 123 insertions(+), 136 deletions(-) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index e7b742d8bec9..09462560f402 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -11,9 +11,9 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement -from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES +from utils import make_rand_tensors from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( w8a8_block_fp8_matmul) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index cab9990b16b5..97ecf141851c 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -5,37 +5,29 @@ """ import dataclasses import os -import pytest -import torch import traceback +from typing import Callable, Concatenate, Optional, ParamSpec -from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, Optional, ParamSpec, Tuple - +import pytest +import torch from pplx_kernels import AllToAll -from pplx_kernels.nvshmem import ( - nvshmem_alloc_empty_unique_id, - nvshmem_finalize, - nvshmem_get_unique_id, - nvshmem_init, -) +from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) -from vllm.platforms import current_platform - from vllm.model_executor.layers.activation import SiluAndMul - -from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, fused_experts) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedDispatchCombine, BatchedExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel -from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine + BatchedExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import ( + PplxDispatchCombine) +from vllm.platforms import current_platform NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] @@ -122,8 +114,7 @@ def parallel_launch( 0, "tcp://localhost:29500", worker, - ) - + args, + ) + args, nprocs=world_size, join=True, ) @@ -157,8 +148,7 @@ def parallel_launch_from_env( node_rank, "env://", worker, - ) - + args, + ) + args, nprocs=world_local_size, join=True, ) @@ -169,19 +159,21 @@ def torch_dispatch( topk_ids: torch.Tensor, num_experts: int, max_num_tokens: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] num_tokens = a.shape[0] topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) if max_num_tokens is None: max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), - dtype=a.dtype, device=a.device) + dtype=a.dtype, + device=a.device) #print(f"b_a shape {b_a.shape}") @@ -191,7 +183,7 @@ def torch_dispatch( for j in range(topk): expert_id = topk_ids[token, j] idx = token_counts[expert_id] - b_a[expert_id, idx:idx+1, :] = a[token, :] + b_a[expert_id, idx:idx + 1, :] = a[token, :] token_counts[expert_id] = token_counts[expert_id] + 1 return b_a, tokens_per_expert @@ -202,13 +194,16 @@ def torch_combine(b_out, topk_weight, topk_ids): num_experts = b_out.shape[0] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx + + 1, :] * topk_weight[token, i] expert_counts[expert_id] = expert_counts[expert_id] + 1 return out @@ -220,13 +215,18 @@ def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): assert b_a.dim() == 3 num_tokens, topk = topk_ids.shape _, max_num_tokens, K = b_a.shape - assert num_experts == b_a.shape[0] and K == w2.shape[1] - out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) + assert num_experts == b_a.shape[0] and w2.shape[1] == K + out = torch.zeros((num_experts, max_num_tokens, K), + dtype=b_a.dtype, + device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), + dtype=b_a.dtype, + device=b_a.device) for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: - torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + torch.ops._C.silu_and_mul( + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) return torch_combine(out, topk_weight, topk_ids) @@ -249,7 +249,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -272,25 +272,8 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - if True: - triton_output = torch_batched_moe(a, - w1, - w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) + triton_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -303,7 +286,7 @@ def rank_chunk(num, r, w): def chunk_by_rank(t, r, w): chunk = rank_chunk(t.shape[0], r, w) #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") - return t[(r * chunk):(r + 1)*chunk] + return t[(r * chunk):(r + 1) * chunk] def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): @@ -317,7 +300,6 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): world_size = pgi.world_size rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") ata = AllToAll( max_num_tokens=max_num_tokens, @@ -328,15 +310,9 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), ) dispatch_combine = PplxDispatchCombine( @@ -350,7 +326,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): a_chunk = chunk_by_rank(a, rank, world_size).to(device) score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, + False) b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -358,17 +335,22 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): None, chunk_topk_weight, chunk_topk_ids, - num_experts, # store at PplxDispatchCombine creation? + num_experts, # store at PplxDispatchCombine creation? None, False, ) - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, + num_experts) torch.distributed.all_reduce(tokens_per_expert) - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, + world_size).to(dtype=torch.int32) - torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) + torch.testing.assert_close(tokens_per_expert, + expert_num_tokens, + atol=0, + rtol=0) b_a = b_a * 1.5 @@ -396,11 +378,15 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - m, n, k, e, + m, + n, + k, + e, topk: int, dtype: torch.dtype, ): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) device = pgi.device @@ -414,17 +400,14 @@ def _pplx_dispatch_combine( a_rep = torch.repeat_interleave(a, topk, dim=0) - torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + torch_output = (a_rep.view(-1, topk, k) * 1.5 * + topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) - pplx_output = torch_pplx_dispatch_combine(pgi, - dp_size, - a, - w1, - w2, - score, + pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, score, topk) - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) @@ -437,7 +420,7 @@ def _pplx_dispatch_combine( @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) def test_pplx_dispatch_combine( m: int, n: int, @@ -445,14 +428,13 @@ def test_pplx_dispatch_combine( e: int, topk: int, dtype: torch.dtype, - world_dp_size: Tuple[int, int], + world_dp_size: tuple[int, int], ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch( - world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype - ) + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, + topk, dtype) def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): @@ -476,15 +458,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), ) w1 = w1.to(device) @@ -508,7 +484,8 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): a_chunk = chunk_by_rank(a, rank, world_size).to(device) score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, + False) out = fused_experts( a_chunk, @@ -519,7 +496,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): #w2, chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? + global_num_experts=num_experts #? num_local_experts? ) torch.cuda.synchronize() @@ -539,7 +516,8 @@ def _pplx_moe( topk: int, dtype: torch.dtype, ): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) @@ -553,15 +531,10 @@ def _pplx_moe( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) + pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk) - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) @@ -575,7 +548,7 @@ def _pplx_moe( @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, n: int, @@ -583,7 +556,7 @@ def test_pplx_moe( e: int, topk: int, dtype: torch.dtype, - world_dp_size: Tuple[int, int], + world_dp_size: tuple[int, int], ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size @@ -592,7 +565,5 @@ def test_pplx_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - parallel_launch( - world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype - ) - + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, + dtype) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a39d08b83768..56b1b343c86e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -9,9 +9,8 @@ class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, - world_size: int, - rank: int): + + def __init__(self, world_size: int, rank: int): super().__init__() self.world_size = world_size self.rank = rank @@ -40,18 +39,22 @@ def dispatch( num_tokens = a1.shape[0] topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) max_num_tokens = tokens_per_expert.max() - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), - dtype=a1.dtype, device=a1.device) + dtype=a1.dtype, + device=a1.device) for token in range(num_tokens): for j in range(topk): expert_id = topk_ids[token, j] idx = expert_counts[expert_id] - b_a1[expert_id, idx:idx+1, :] = a1[token, :] + b_a1[expert_id, idx:idx + 1, :] = a1[token, :] expert_counts[expert_id] = expert_counts[expert_id] + 1 return b_a1, a1_scale, tokens_per_expert @@ -66,7 +69,9 @@ def combine( ) -> None: num_tokens = topk_ids.shape[0] num_experts = fused_expert_output.shape[0] - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=fused_expert_output.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(topk_ids.shape[1]): @@ -74,9 +79,14 @@ def combine( if expert_id < num_experts: idx = expert_counts[expert_id] if apply_router_weight_on_input: - output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] + output[token, :] = output[ + token, :] + fused_expert_output[expert_id, + idx:idx + 1, :] else: - output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] + output[ + token, :] = output[token, :] + fused_expert_output[ + expert_id, + idx:idx + 1, :] * topk_weights[token, i] expert_counts[expert_id] = expert_counts[expert_id] + 1 @@ -122,8 +132,10 @@ def workspace_shapes( topk: int, num_experts: int, ) -> Tuple[int, int, torch.dtype]: - max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack + max_num_tokens = a.shape[ + 1] if self.max_num_tokens is None else self.max_num_tokens + # TODO: *2 is a hack + workspace13 = num_experts * max_num_tokens * K * topk * 2 workspace2 = max_num_tokens * N return (workspace13, workspace2, a.dtype) @@ -148,16 +160,21 @@ def apply( ) -> torch.Tensor: assert hidden_states.dim() == 3 assert expert_num_tokens is not None - num_tokens = topk_ids.shape[0] - _, tmp_max_num_tokens, K = hidden_states.shape - max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens + + if self.max_num_tokens is None: + max_num_tokens = hidden_states.shape[1] + else: + max_num_tokens = self.max_num_tokens + num_experts = global_num_experts - out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + out = _resize_cache(workspace13, + (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() # TODO: don't need world_size or rank if expert_base always == 0 #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" - #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank + #expert_base = rank_chunk(w1.shape[0], self.rank, + # self.world_size) * self.rank expert_base = 0 for expert in range(num_local_experts): @@ -166,10 +183,9 @@ def apply( if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation( - activation, - tmp, - hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1) - ) - out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) + activation, tmp, hidden_states[expert, :num, :] + @ w1[expert_base + expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert_base + + expert].transpose(0, 1) return out From 64acde9d8bb889ad28b70c1fed7ca1df8f26f06b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:14:05 +0000 Subject: [PATCH 030/171] undo random lint changes Signed-off-by: Bill Nell --- benchmarks/cutlass_benchmarks/w8a8_benchmarks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 09462560f402..e7b742d8bec9 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -11,9 +11,9 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES -from utils import make_rand_tensors from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( w8a8_block_fp8_matmul) From 17e6e00eb2274b227860f64abc88f3b1161fffc6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:34:40 +0000 Subject: [PATCH 031/171] more lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 97ecf141851c..f0dabd66feaa 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -6,7 +6,7 @@ import dataclasses import os import traceback -from typing import Callable, Concatenate, Optional, ParamSpec +from typing import Callable, Optional import pytest import torch @@ -16,6 +16,7 @@ nvshmem_init) from torch.multiprocessing import ( spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec import vllm.model_executor.layers.fused_moe # noqa from vllm.config import VllmConfig, set_current_vllm_config @@ -169,7 +170,7 @@ def torch_dispatch( tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) if max_num_tokens is None: - max_num_tokens = tokens_per_expert.max() + max_num_tokens = int(tokens_per_expert.max().item()) b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, From 10398519b2a8c6a8aa865cde4b165e6d1571b09a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:46:13 +0000 Subject: [PATCH 032/171] more lint nonsense Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f0dabd66feaa..405ced54d2ee 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -94,7 +94,7 @@ def _worker_parallel_launch( ) except Exception as ex: print(ex) - traceback.print_exception(ex) + traceback.print_exc() raise finally: torch.distributed.destroy_process_group() @@ -176,8 +176,6 @@ def torch_dispatch( dtype=a.dtype, device=a.device) - #print(f"b_a shape {b_a.shape}") - token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) for token in range(num_tokens): From 89de35ffde3fc8771b9d4e6cd92b5dbb6040233d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 15 Mar 2025 01:11:06 +0000 Subject: [PATCH 033/171] WIP torch while Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/cuda_graph_utils.py | 0 vllm/forward_context.py | 3 + vllm/model_executor/layers/fused_moe/layer.py | 74 +++++++++++++++++++ 3 files changed, 77 insertions(+) create mode 100644 vllm/cuda_graph_utils.py diff --git a/vllm/cuda_graph_utils.py b/vllm/cuda_graph_utils.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/forward_context.py b/vllm/forward_context.py index ea3aaa66b4cf..948e63327b09 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -31,6 +31,7 @@ @dataclass class DPMetadata: + max_tokens_across_dp: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor dp_rank_num_tokens: torch.Tensor @@ -90,6 +91,8 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + #TODO device? + max_tokens_across_dp = torch.max(num_tokens_tensor).to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_rank_num_tokens = torch.tensor( [num_tokens], diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3cdf3c97a7d3..9dd4bf410be8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -840,6 +840,80 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) + def forward_impl_while(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + max_tokens_across_dp = get_forward_context( + ).dp_metadata.max_tokens_across_dp + + #TODO: we need to define a couple of ranges: + # 1. the range within this rank's M dimension that we are looping over + # 2. the range within the workspace buffer that our current chunk maps to. + + moe_dp_chunk_size = 256 + my_dp_chunk_size = moe_dp_chunk_size // self.dp_size + chunk_start = torch.tensor(0, device=hidden_states.device) + + def padded_allgather(self, x: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.zeros((moe_dp_chunk_size, x.shape[1]), + device=x.device, + dtype=x.dtype) + + buffer[:x.shape[0], :].copy_(x) + get_dp_group().all_gather(buffer, 0) + return buffer + + def cond_fn(chunk_range, max_tokens_across_dp, hidden_states, + router_logits): + return chunk_range[0] < max_tokens_across_dp + + def body_fn(chunk_range, max_tokens_across_dp, full_hidden_states, + full_router_logits): + hidden_states = full_hidden_states[chunk_range] + router_logits = full_router_logits[chunk_range] + + if self.dp_size > 1: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.padded_allgather(hidden_states) + router_logits = self.padded_allgather(router_logits) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + ) + + if self.dp_size > 1: + all_hidden_states = get_dp_group().all_reduce( + final_hidden_states) + final_hidden_states[chunk_range] = all_hidden_states[ + start:end, :] + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + chunk_range[0] = min(hidden_states.shape[0], + chunk_range[0] + moe_dp_chunk_size) + chunk_range[1] = min(hidden_states.shape[0], + chunk_range[1] + moe_dp_chunk_size) + return chunk_start, hidden_states + def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None From 2c123927760a52a31815a21c8070f487ec270f1a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 25 Mar 2025 13:10:57 +0000 Subject: [PATCH 034/171] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9dd4bf410be8..88f2ebfaefe8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1006,7 +1006,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, self = forward_context.no_compile_layers[layer_name] assert self.quant_method is not None - return self.forward_impl(hidden_states, router_logits) + return self.forward_impl_while(hidden_states, router_logits) def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, From 8c19435481438011d99e66b7cd77c30d217504e0 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 25 Mar 2025 21:32:43 +0000 Subject: [PATCH 035/171] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 88f2ebfaefe8..311590d0c69e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -844,10 +844,19 @@ def forward_impl_while(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu - #TODO: we need to define a couple of ranges: - # 1. the range within this rank's M dimension that we are looping over - # 2. the range within the workspace buffer that our current chunk maps to. + #In this function we define two ranges: + # 1. chunk_range - The current iteration of the loops's range over the DP world tokens + # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. + + chunk_range = torch.zeros(2, device=hidden_states.device) + chunk_range[1] = min(moe_dp_chunk_size, cu_tokens_across_dp_cpu[-1]) + + my_tokens_in_chunk = torch.zeros(2, device=hidden_states.device) + my_tokens_in_chunk[1] = min(my_dp_chunk_size, + chunk_range[1] - chunk_range[0]) moe_dp_chunk_size = 256 my_dp_chunk_size = moe_dp_chunk_size // self.dp_size From 49d2658f0615c8cf784346c51ac0ef77d88d908e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 27 Mar 2025 13:48:42 +0000 Subject: [PATCH 036/171] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/forward_context.py | 6 +- vllm/model_executor/layers/fused_moe/layer.py | 80 +++++++++---------- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 948e63327b09..1afdf88ec2da 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -32,6 +32,7 @@ @dataclass class DPMetadata: max_tokens_across_dp: torch.Tensor + num_tokens_across_dp: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor dp_rank_num_tokens: torch.Tensor @@ -98,7 +99,10 @@ def set_forward_context(attn_metadata: Any, [num_tokens], dtype=torch.uint32, device=vllm_config.device_config.device) - dp_metadata = DPMetadata(cu_tokens_across_dp_cpu, dp_rank_num_tokens) + dp_metadata = DPMetadata(max_tokens_across_dp, + num_tokens_tensor, + cu_tokens_across_dp_cpu, + dp_rank_num_tokens) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 311590d0c69e..7af05dc3da6c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -840,53 +840,43 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) - def forward_impl_while(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl_while(self, full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu + num_tokens_across_dp = get_forward_context( + ).dp_metadata.num_tokens_across_dp - #In this function we define two ranges: - # 1. chunk_range - The current iteration of the loops's range over the DP world tokens - # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. - - chunk_range = torch.zeros(2, device=hidden_states.device) - chunk_range[1] = min(moe_dp_chunk_size, cu_tokens_across_dp_cpu[-1]) - - my_tokens_in_chunk = torch.zeros(2, device=hidden_states.device) - my_tokens_in_chunk[1] = min(my_dp_chunk_size, - chunk_range[1] - chunk_range[0]) - - moe_dp_chunk_size = 256 - my_dp_chunk_size = moe_dp_chunk_size // self.dp_size - chunk_start = torch.tensor(0, device=hidden_states.device) - - def padded_allgather(self, x: torch.Tensor): + def padded_allgather(x: torch.Tensor): assert (len(x.shape) == 2) buffer = torch.zeros((moe_dp_chunk_size, x.shape[1]), device=x.device, dtype=x.dtype) - buffer[:x.shape[0], :].copy_(x) get_dp_group().all_gather(buffer, 0) return buffer - def cond_fn(chunk_range, max_tokens_across_dp, hidden_states, - router_logits): - return chunk_range[0] < max_tokens_across_dp + #In this function we define two ranges: + # 1. chunk_range - The current iteration of the loops's range over the DP world tokens + # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. + + moe_dp_chunk_size = 256 + moe_dp_chunk_size_per_rank = moe_dp_chunk_size // self.dp_size + + num_tokens_remaining_across_dp = num_tokens_across_dp + chunk_start = 0 + chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + full_final_hidden_states = torch.empty_like(full_hidden_states) - def body_fn(chunk_range, max_tokens_across_dp, full_hidden_states, - full_router_logits): - hidden_states = full_hidden_states[chunk_range] - router_logits = full_router_logits[chunk_range] + for _ in range(max_tokens_across_dp, moe_dp_chunk_size_per_rank): + hidden_states = full_hidden_states[chunk_start:chunk_end,:] + router_logits = full_router_logits[chunk_start:chunk_end,:] if self.dp_size > 1: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - - hidden_states = self.padded_allgather(hidden_states) - router_logits = self.padded_allgather(router_logits) + hidden_states = padded_allgather(hidden_states) + router_logits = padded_allgather(router_logits) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -906,22 +896,32 @@ def body_fn(chunk_range, max_tokens_across_dp, full_hidden_states, activation=self.activation, ) + cu_tokens_across_dp_this_iter = torch.cumsum( + num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), + dim=0) + if self.dp_size > 1: + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank-1] + end = cu_tokens_across_dp_this_iter[self.dp_rank] + all_hidden_states = get_dp_group().all_reduce( final_hidden_states) - final_hidden_states[chunk_range] = all_hidden_states[ - start:end, :] + final_hidden_states = all_hidden_states[start:end, :] if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + full_final_hidden_states[chunk_start:chunk_end,:].copy_(final_hidden_states) + + num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) + chunk_start = min(chunk_start + moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) + chunk_end = min(chunk_end + moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) + + return full_final_hidden_states - chunk_range[0] = min(hidden_states.shape[0], - chunk_range[0] + moe_dp_chunk_size) - chunk_range[1] = min(hidden_states.shape[0], - chunk_range[1] + moe_dp_chunk_size) - return chunk_start, hidden_states def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): From 36bb8804f40bbe44128c3cbea2d098780130fc7f Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 27 Mar 2025 13:41:18 -0400 Subject: [PATCH 037/171] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 41 ++++++++----------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e153e165b621..cacd4ac9ccb4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1462,8 +1462,8 @@ def fused_moe( Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7af05dc3da6c..82fe51c1e87f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -840,7 +840,7 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) - def forward_impl_while(self, full_hidden_states: torch.Tensor, + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp @@ -849,15 +849,6 @@ def forward_impl_while(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp - def padded_allgather(x: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.zeros((moe_dp_chunk_size, x.shape[1]), - device=x.device, - dtype=x.dtype) - buffer[:x.shape[0], :].copy_(x) - get_dp_group().all_gather(buffer, 0) - return buffer - #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. @@ -870,13 +861,18 @@ def padded_allgather(x: torch.Tensor): chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - for _ in range(max_tokens_across_dp, moe_dp_chunk_size_per_rank): + for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): hidden_states = full_hidden_states[chunk_start:chunk_end,:] router_logits = full_router_logits[chunk_start:chunk_end,:] - if self.dp_size > 1: - hidden_states = padded_allgather(hidden_states) - router_logits = padded_allgather(router_logits) + cu_tokens_across_dp_this_iter = torch.cumsum( + num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), + dim=0) + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_this_iter) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_this_iter) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -896,10 +892,6 @@ def padded_allgather(x: torch.Tensor): activation=self.activation, ) - cu_tokens_across_dp_this_iter = torch.cumsum( - num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), - dim=0) - if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank-1] end = cu_tokens_across_dp_this_iter[self.dp_rank] @@ -912,13 +904,14 @@ def padded_allgather(x: torch.Tensor): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - full_final_hidden_states[chunk_start:chunk_end,:].copy_(final_hidden_states) + full_final_hidden_states[chunk_start:chunk_end, :].copy_(final_hidden_states) + # Update bounds num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) - chunk_start = min(chunk_start + moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) - chunk_end = min(chunk_end + moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) + def update_chunk_bound(x: int): + return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + chunk_start = update_chunk_bound(chunk_start) + chunk_end = update_chunk_bound(chunk_end) return full_final_hidden_states @@ -1015,7 +1008,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, self = forward_context.no_compile_layers[layer_name] assert self.quant_method is not None - return self.forward_impl_while(hidden_states, router_logits) + return self.forward_impl_chunked(hidden_states, router_logits) def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, From 09a98133e8ada391b61baa2b3366994921afc68d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 28 Mar 2025 16:35:28 -0400 Subject: [PATCH 038/171] WIP integration Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 82fe51c1e87f..279ba2778b1f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,11 +3,14 @@ from abc import abstractmethod from enum import Enum from typing import Callable, List, Optional, Tuple +from dataclasses import dataclass import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter +import pplx_kernels as pplx + import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, @@ -34,6 +37,24 @@ fused_moe_pallas = None # type: ignore logger = init_logger(__name__) +MOE_DP_CHUNK_SIZE = 256 + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class MoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + dp_size: int + dp_rank: int + ep_size: int + ep_rank: int + + in_dtype: torch.dtype = torch.bfloat16 + out_dtype: torch.dtype = torch.bfloat16 + block_size: int = 128 class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -71,10 +92,22 @@ def apply( ) -> torch.Tensor: raise NotImplementedError - @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self, moe: MoEConfig): + self.all_to_all = pplx.AllToAll( + max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, + rank=moe.ep_rank, + world_size=moe.ep_size, + dp_size=moe.dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + hidden_dim_scale_bytes=0, + ) + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -853,8 +886,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. - moe_dp_chunk_size = 256 - moe_dp_chunk_size_per_rank = moe_dp_chunk_size // self.dp_size + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size num_tokens_remaining_across_dp = num_tokens_across_dp chunk_start = 0 From af8fd7cf201f09b6f44f57b812f4f0730f2bc4c2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 26 Feb 2025 23:09:34 +0000 Subject: [PATCH 039/171] Add test for deep gemm matmul Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 347 ++++++++++++++++++++++++++++++++ 1 file changed, 347 insertions(+) create mode 100644 tests/kernels/test_block_fp8.py diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py new file mode 100644 index 000000000000..bebc77dcec9e --- /dev/null +++ b/tests/kernels/test_block_fp8.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +import deep_gemm + +import itertools +import pytest +import torch + +from typing import Tuple + +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, w8a8_block_fp8_matmul) +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (9, 0): + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", + allow_module_level=True) + +# Test configurations +DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] +NUM_TOKENS = [7, 83, 2048] +D = [512, 4096, 5120, 13824] +GROUP_SIZE = [64, 128, 256, 512] +M = [1, 7, 83, 512, 2048] +N = [128, 512, 1024, 4096, 7748, 13824] +K = [256, 4096, 5120, 3884, 13824] +# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 +# and its hidden size is 7168. +M_moe = [1, 7, 83, 512, 2048] +N_moe = [4608] # [128, 4608, 13824] +K_moe = [7168] # [256, 7168, 13824] +BLOCK_SIZE = [[128, 128]] +E = [8, 24] # [8, 24, 128, 256] +TOP_KS = [2] # [1, 2, 6] +OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] +SEEDS = [0] + + +def native_per_token_group_quant_fp8(x, + group_size, + eps=1e-10, + dtype=torch.float8_e4m3fn): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch.""" + assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " + "be divisible by `group_size`") + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +def native_w8a8_block_fp8_matmul(A, + B, + As, + Bs, + block_size, + output_dtype=torch.float16): + """Matrix multiplication with block-wise quantization using native torch.""" + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +# Skip all tests if CUDA is not available +pytest.importorskip("torch.cuda") + + +@pytest.fixture(autouse=True) +def setup_cuda(): + torch.set_default_device("cuda") + + +@pytest.mark.parametrize( + "num_tokens,d,dtype,group_size,seed", + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) +@torch.inference_mode() +def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): + torch.manual_seed(seed) + x = torch.rand(num_tokens, d, dtype=dtype) + + ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) + out, scale = per_token_group_quant_fp8(x, group_size) + + assert torch.allclose(out.to(torch.float32), + ref_out.to(torch.float32), + rtol=0.15) + assert torch.allclose(scale, ref_scale) + + +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + +@pytest.mark.parametrize( + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale + w2_s = torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale + + score = torch.randn((M, E), dtype=dtype) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 + + +######################################################################################### + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, 128) * 128), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + + # only aligned sizes + if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: + return + + # weird max diff errors + if False and (M == 512 or M == 2048): + return + + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + + A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) + B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) + + # Transpose earlier so that the testing will not trigger transposing kernels + As_dg = deep_gemm.get_col_major_tma_aligned_tensor(As_dg) + + out = torch.empty((M, N), device='cuda', dtype=out_dtype) + + assert As_dg.shape == (M, (K + 127) // 128), f"{As_dg.shape} != {(M, (K + 127) // 128)}" + + deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8_dg, As_dg), (B_fp8_dg, Bs_dg), out) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 From 3ab544326c9682e8ac2e535bfb27c8ea8e7e30d3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Feb 2025 03:01:01 +0000 Subject: [PATCH 040/171] fix matmul test Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 171 +++++++++++++++++++++++++++++--- 1 file changed, 157 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index bebc77dcec9e..249da81b32a3 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/sgl-project/sglang/pull/2575 +# TODO: try/catch this? import deep_gemm import itertools @@ -24,12 +25,14 @@ NUM_TOKENS = [7, 83, 2048] D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] -M = [1, 7, 83, 512, 2048] +#M = [1, 7, 83, 512, 2048] +M = [1, 8, 84, 512, 2048] N = [128, 512, 1024, 4096, 7748, 13824] K = [256, 4096, 5120, 3884, 13824] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 7, 83, 512, 2048] +#M_moe = [1, 7, 83, 512, 2048] +M_moe = [1, 8, 84, 512, 2048] N_moe = [4608] # [128, 4608, 13824] K_moe = [7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] @@ -299,16 +302,11 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): - torch.manual_seed(seed) - # only aligned sizes if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: return - # weird max diff errors - if False and (M == 512 or M == 2048): - return - + torch.manual_seed(seed) factor_for_scale = 1e-2 fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -323,19 +321,22 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale - Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) + B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + As = As_dg.to(torch.float32) + Bs = Bs_dg.to(torch.float32) + + ref_out = native_w8a8_block_fp8_matmul(A_fp8_dg, B_fp8_dg, As, Bs, block_size, out_dtype) - A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) - B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) + #A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) + #B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) # Transpose earlier so that the testing will not trigger transposing kernels As_dg = deep_gemm.get_col_major_tma_aligned_tensor(As_dg) - out = torch.empty((M, N), device='cuda', dtype=out_dtype) + out = torch.zeros((M, N), device='cuda', dtype=out_dtype) assert As_dg.shape == (M, (K + 127) // 128), f"{As_dg.shape} != {(M, (K + 127) // 128)}" @@ -345,3 +346,145 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 + + +################################################################################### + +def construct_grouped( + num_groups: int, + m: int, + k: int, + n: int, + is_masked: bool +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16) + + assert m % 4 == 0, f'TMA alignment error: {m}' + x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + # For non-masked input, we must merge the group and M dims + if not is_masked: + x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) + out, ref_out = out.view(-1, n), ref_out.view(-1, n) + + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out + + +# ref_out = torch.einsum('gmk,gnk->gmn', x, y) + +from vllm.model_executor.layers.fused_moe import fused_topk, grouped_topk + +def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = per_token_group_quant_fp8(a, block_k) + w1, w1_s = per_block_cast_to_fp8(w1) + w2, w2_s = per_block_cast_to_fp8(w2) + + num_groups = w1.shape[0] # ??? + + m_indices = torch.arange(0, num_groups, device=a.device, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + + inter_out = torch.zeros(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), + (w1, w1_s), + inter_out, + m_indices) + + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + + num_groups2 = w2.shape[0] # ??? + + m_indices2 = torch.arange(0, num_groups2, device=a.device, dtype=torch.int) + m_indices2 = m_indices2.unsqueeze(-1).expand(num_groups2, n).contiguous().view(-1) + out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), + (w2, w2_s), + out, + m_indices2) + + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize( + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + + # only aligned sizes + if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: + return + + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale + w2_s = torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale + + score = torch.randn((M, E), dtype=dtype) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 From 187eadfd5a18d7bbaccc9ff110c9379c902a1175 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Feb 2025 19:55:59 +0000 Subject: [PATCH 041/171] running Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 59 +++++++++++-------- .../layers/fused_moe/fused_moe.py | 1 + 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 249da81b32a3..1028310b5ca6 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -26,9 +26,15 @@ D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] #M = [1, 7, 83, 512, 2048] -M = [1, 8, 84, 512, 2048] -N = [128, 512, 1024, 4096, 7748, 13824] -K = [256, 4096, 5120, 3884, 13824] + +M = [1, 8, 84, 512, 2048, 4096] +N = [128, 512, 1024, 4096, 7748, 13824, 7168] +K = [256, 4096, 5120, 3884, 13824, 16384] + +#M = [128] +#N = [24576] +#K = [1536] + # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. #M_moe = [1, 7, 83, 512, 2048] @@ -384,46 +390,50 @@ def construct_grouped( def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + M, K = a.shape + print(f"before {a.shape}") + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) + topk_ids = topk_ids.to(dtype=torch.int32).view(-1) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_k) - w1, w1_s = per_block_cast_to_fp8(w1) - w2, w2_s = per_block_cast_to_fp8(w2) - num_groups = w1.shape[0] # ??? + num_groups = w1.shape[0] + for i in range(num_groups): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16)) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2[i].to(dtype=torch.bfloat16)) + + print(f"{M}, {num_groups}, {a.shape}") m_indices = torch.arange(0, num_groups, device=a.device, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + m_indices = m_indices.unsqueeze(-1).expand(num_groups, a.shape[0]//num_groups).contiguous().view(-1) inter_out = torch.zeros(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) + print("FIRST GEMM") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, - m_indices) + topk_ids) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - num_groups2 = w2.shape[0] # ??? + out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - m_indices2 = torch.arange(0, num_groups2, device=a.device, dtype=torch.int) - m_indices2 = m_indices2.unsqueeze(-1).expand(num_groups2, n).contiguous().view(-1) - out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + print("SECOND GEMM") deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), (w2, w2_s), out, - m_indices2) + topk_ids) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( @@ -446,11 +456,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, w1_bf16 = (torch.rand( (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max) del w1_bf16 w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max) del w2_bf16 block_n, block_k = block_size[0], block_size[1] @@ -466,6 +476,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, score = torch.randn((M, E), dtype=dtype) + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + + ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + out = fused_moe( a, w1, @@ -478,9 +494,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, w2_scale=w2_s, block_shape=block_size, ) - ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - print(f"{out.sum()=}") print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index cacd4ac9ccb4..5c91c47790bf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -503,6 +503,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k + # EM = num_groups EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. From 45fd37fcc4b4b4d3fdedaa0766981415c485fe53 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Feb 2025 20:54:25 +0000 Subject: [PATCH 042/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 143 ++++++++++++++++---------------- 1 file changed, 70 insertions(+), 73 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 1028310b5ca6..7a9a46291ae9 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,6 +10,7 @@ from typing import Tuple +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -292,12 +293,12 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) -def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8(x: torch.Tensor, block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, 128) * 128), dtype=x.dtype, device=x.device) + x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, block_size_n) * block_size_n), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) @@ -388,32 +389,40 @@ def construct_grouped( from vllm.model_executor.layers.fused_moe import fused_topk, grouped_topk -def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): +def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_shape): """Fused moe with block-wise quantization using native torch.""" - M, K = a.shape - print(f"before {a.shape}") - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.to(dtype=torch.int32).view(-1) - _, block_k = block_shape[0], block_shape[1] + M, K = a.shape + N = w2.shape[-1] + num_groups = w1.shape[0] + + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + + block_n, block_k = block_shape[0], block_shape[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + a_q, a_s = per_token_group_quant_fp8(a, block_k) - num_groups = w1.shape[0] for i in range(num_groups): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16)) + w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16), block_n) w2[i], w2_s[i] = per_block_cast_to_fp8(w2[i].to(dtype=torch.bfloat16)) - print(f"{M}, {num_groups}, {a.shape}") + inter_out = torch.empty(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) - m_indices = torch.arange(0, num_groups, device=a.device, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(num_groups, a.shape[0]//num_groups).contiguous().view(-1) + #print("FIRST GEMM") - inter_out = torch.zeros(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) - - print("FIRST GEMM") + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), @@ -425,7 +434,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - print("SECOND GEMM") + #print("SECOND GEMM") deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), (w2, w2_s), @@ -433,7 +442,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape topk_ids) return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1), w1_s, w2_s @pytest.mark.parametrize( @@ -444,60 +453,48 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes - if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: + if M % 4 != 0 or K % 128 != 0 or N % 128 != 0: return - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max) - del w2_bf16 - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale - w2_s = torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale - - score = torch.randn((M, E), dtype=dtype) - - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) - - ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - print(f"{out.sum()=}") - print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + score = torch.randn((M, E), dtype=dtype) + + # TODO: move out scale setup + ref_out, w1_s, w2_s = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_size) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 From 1c54fa96a909feb3630608254d17e4ee38bb2eda Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 04:23:16 +0000 Subject: [PATCH 043/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 290 +++++++++++++++++--------------- 1 file changed, 151 insertions(+), 139 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 7a9a46291ae9..97b99445536c 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -2,14 +2,13 @@ # Adapted from https://github.com/sgl-project/sglang/pull/2575 # TODO: try/catch this? -import deep_gemm - import itertools +from typing import Tuple + +import deep_gemm import pytest import torch -from typing import Tuple - from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe @@ -43,7 +42,8 @@ N_moe = [4608] # [128, 4608, 13824] K_moe = [7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -E = [8, 24] # [8, 24, 128, 256] +#E = [8, 24] # [8, 24, 128, 256] +E = [8, 16] # [8, 24, 128, 256] TOP_KS = [2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -285,23 +285,33 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ######################################################################################### -def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + +def per_token_cast_to_fp8( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to( + torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) -def per_block_cast_to_fp8(x: torch.Tensor, block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, block_size_n) * block_size_n), dtype=x.dtype, device=x.device) + x_padded = torch.zeros( + (deep_gemm.cell_div(m, 128) * 128, + deep_gemm.cell_div(n, block_size_n) * block_size_n), + dtype=x.dtype, + device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( + x_amax / 448.0).view(x_view.size(0), x_view.size(2)) @pytest.mark.parametrize( @@ -314,40 +324,32 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): return torch.manual_seed(seed) - factor_for_scale = 1e-2 fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min + fp8_max = fp8_info.max A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - block_n, block_k = block_size[0], block_size[1] - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k + _, block_k = block_size[0], block_size[1] - A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) - B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) + A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) + B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) - As = As_dg.to(torch.float32) - Bs = Bs_dg.to(torch.float32) + As = As_fp8.to(torch.float32) + Bs = Bs_fp8.to(torch.float32) - ref_out = native_w8a8_block_fp8_matmul(A_fp8_dg, B_fp8_dg, As, Bs, block_size, + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - #A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) - #B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) - # Transpose earlier so that the testing will not trigger transposing kernels - As_dg = deep_gemm.get_col_major_tma_aligned_tensor(As_dg) + As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) out = torch.zeros((M, N), device='cuda', dtype=out_dtype) - assert As_dg.shape == (M, (K + 127) // 128), f"{As_dg.shape} != {(M, (K + 127) // 128)}" + assert As_fp8.shape == (M, (K + 127) // + 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8_dg, As_dg), (B_fp8_dg, Bs_dg), out) + deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / @@ -357,144 +359,154 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): ################################################################################### -def construct_grouped( - num_groups: int, - m: int, - k: int, - n: int, - is_masked: bool -) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: - x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16) - - assert m % 4 == 0, f'TMA alignment error: {m}' - x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) - for i in range(num_groups): - x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - - # For non-masked input, we must merge the group and M dims - if not is_masked: - x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) - out, ref_out = out.view(-1, n), ref_out.view(-1, n) - - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out - - # ref_out = torch.einsum('gmk,gnk->gmn', x, y) -from vllm.model_executor.layers.fused_moe import fused_topk, grouped_topk -def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" +def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_shape): + """Fused moe with block-wise quantization using DeepGemm torch.""" + M = a.numel() // a.shape[-1] + K = w1.shape[-1] + num_groups = w1.shape[0] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + inter_out = torch.zeros(a.shape[0], + w1.shape[1], + dtype=torch.bfloat16, + device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.to(dtype=torch.int32).view(-1) - - M, K = a.shape - N = w2.shape[-1] - num_groups = w1.shape[0] - - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - - block_n, block_k = block_shape[0], block_shape[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + topk_ids = topk_ids.view(-1) + _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_k) - for i in range(num_groups): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16), block_n) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2[i].to(dtype=torch.bfloat16)) + #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) + #print(f"FIRST GEMM {a_q.shape}") - inter_out = torch.empty(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) + m_indices = torch.arange(0, num_groups, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand( + num_groups, (2 * M) // num_groups).contiguous().view(-1) + #print(f"m_indices {m_indices.shape}, ng={num_groups}") - #print("FIRST GEMM") - - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), - (w1, w1_s), - inter_out, - topk_ids) + if True: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a_q, a_s), (w1, w1_s), inter_out, m_indices) + else: + topk_ids = topk_ids.to(dtype=torch.int32) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), + inter_out, topk_ids, M) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + out = torch.zeros(M * topk, + w2.shape[1], + dtype=torch.bfloat16, + device=a.device) #print("SECOND GEMM") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), - (w2, w2_s), - out, - topk_ids) + if True: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (act_out_q, act_out_s), (w2, w2_s), out, m_indices) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1), w1_s, w2_s + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) + itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, + dtype, seed): # only aligned sizes - if M % 4 != 0 or K % 128 != 0 or N % 128 != 0: + if (M % 4 != 0 or N % 128 != 0 or K % 128 != 0): return vllm_config = VllmConfig() + + torch.manual_seed(seed) + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + score = torch.randn((M, E), dtype=dtype) + + num_groups = E + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = ((2 * N) + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w2 = (N + block_k - 1) // block_k + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + + w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), + dtype=torch.float32) + w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), + dtype=torch.float32) + + assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) + assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + for i in range(num_groups): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + with set_current_vllm_config(vllm_config): - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w2_bf16 - - score = torch.randn((M, E), dtype=dtype) - - # TODO: move out scale setup - ref_out, w1_s, w2_s = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_size) - - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 + if False: + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + else: + out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + ref_out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 From 8752d63ceeb743854a5a99a6f246a384fb95f1a6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 22:03:27 +0000 Subject: [PATCH 044/171] debugging Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 99 ++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 97b99445536c..0093d74efa70 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +from vllm.utils import cdiv if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -223,11 +224,13 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 +def p(s, t): + print(f"{s}: {t.shape}, {t.dtype}") @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([4], [128], [128], [8], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) @@ -235,6 +238,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min + vllm_config = VllmConfig() + a = torch.randn((M, K), dtype=dtype) / 10 w1_bf16 = (torch.rand( @@ -259,20 +264,27 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + p("a", a) + p("w1", w1) + p("w1_s", w1_s) + p("w2", w2) + p("w2_s", w2_s) + + with set_current_vllm_config(vllm_config): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) print(f"{out.sum()=}") print(f"{ref_out.sum()=}") @@ -310,8 +322,9 @@ def per_block_cast_to_fp8( x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( - x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales @pytest.mark.parametrize( @@ -369,7 +382,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, K = w1.shape[-1] num_groups = w1.shape[0] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - inter_out = torch.zeros(a.shape[0], + inter_out = torch.empty(a.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) @@ -386,8 +399,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, m_indices = torch.arange(0, num_groups, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand( - num_groups, (2 * M) // num_groups).contiguous().view(-1) - #print(f"m_indices {m_indices.shape}, ng={num_groups}") + num_groups, max(M // num_groups, 1)).contiguous().view(-1) + p("m_indices", m_indices) if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -400,13 +413,13 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(M * topk, + #print("SECOND GEMM") + + out = torch.empty(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, device=a.device) - #print("SECOND GEMM") - if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) @@ -420,13 +433,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes - if (M % 4 != 0 or N % 128 != 0 or K % 128 != 0): + if (N % 128 != 0 or K % 128 != 0): + print(f"skip {N}, {K}") return vllm_config = VllmConfig() @@ -460,14 +475,35 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + print(f"NUM_GROUPS = {num_groups}") + p("before w1_s", w1_s) + p("before w2_s", w2_s) + assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] for i in range(num_groups): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + p("imm w1_s", w1_s) + + w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + + if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: + p("w1_sa", w1_sa) + p("w2_sa", w2_sa) + print(f"UNALIGNED") + return + + w1_s = w1_sa + w2_s = w2_sa + + p("a", a) + p("w1", w1) + p("final w1_s", w1_s) + p("w2", w2) + p("w2_s", w2_s) with set_current_vllm_config(vllm_config): if False: @@ -487,9 +523,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) else: - out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - ref_out = fused_moe( a, w1, @@ -503,6 +536,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) + out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") From 9ac9bfe0329b0ede26f8e950845a6a6fb3414441 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 22:04:31 +0000 Subject: [PATCH 045/171] debugging Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5c91c47790bf..6a8929e2a5ee 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,6 +1360,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + print(intermediate_cache2) + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From 2724f059c74bbea63f25643272b0a09c748e92b5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 22:04:45 +0000 Subject: [PATCH 046/171] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6a8929e2a5ee..5c91c47790bf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,8 +1360,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - print(intermediate_cache2) - if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From e86cf1d32add3d1b9ea1eb2d4ed42607da7cddcf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 23:41:36 +0000 Subject: [PATCH 047/171] update deep gemm Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 30 ++++++++++++------- .../layers/fused_moe/fused_moe.py | 2 ++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 0093d74efa70..ea96724f5590 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -229,8 +229,7 @@ def p(s, t): @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([4], [128], [128], [8], [2], [[128, 128]], DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) @@ -314,8 +313,8 @@ def per_block_cast_to_fp8( assert x.dim() == 2 m, n = x.shape x_padded = torch.zeros( - (deep_gemm.cell_div(m, 128) * 128, - deep_gemm.cell_div(n, block_size_n) * block_size_n), + (deep_gemm.ceil_div(m, 128) * 128, + deep_gemm.ceil_div(n, block_size_n) * block_size_n), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x @@ -334,7 +333,7 @@ def per_block_cast_to_fp8( def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: - return + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) @@ -399,8 +398,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, m_indices = torch.arange(0, num_groups, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand( - num_groups, max(M // num_groups, 1)).contiguous().view(-1) + num_groups, max((topk * M) // num_groups, 1)).contiguous().view(-1) + #m_indices = torch.IntTensor([0, 1]) p("m_indices", m_indices) + print(m_indices) + + print("topk", topk_ids) + print(topk_ids) + print("topk_weight", topk_weight) + print(topk_weight) if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -410,6 +416,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) + print(f"DG {inter_out.shape} {inter_out}") + act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) @@ -441,8 +449,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # only aligned sizes if (N % 128 != 0 or K % 128 != 0): - print(f"skip {N}, {K}") - return + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + + torch.set_printoptions(profile="full") vllm_config = VllmConfig() @@ -490,11 +499,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + # TODO: move size alignment further up when setting up all shapes if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: p("w1_sa", w1_sa) p("w2_sa", w2_sa) - print(f"UNALIGNED") - return + print("UNALIGNED") + pytest.skip("UNALIGNED") w1_s = w1_sa w2_s = w2_sa diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5c91c47790bf..f3e8c59c3612 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,6 +1360,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From 353687d982249879bdf4c4abbd1617ac5b3c8f77 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Mar 2025 00:21:16 +0000 Subject: [PATCH 048/171] update deep gemm + small test case Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ea96724f5590..cc2d1d8673f0 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform -from vllm.utils import cdiv if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -224,12 +223,15 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 + def p(s, t): print(f"{s}: {t.shape}, {t.dtype}") + @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) @@ -399,7 +401,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, m_indices = torch.arange(0, num_groups, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand( num_groups, max((topk * M) // num_groups, 1)).contiguous().view(-1) - #m_indices = torch.IntTensor([0, 1]) + #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) p("m_indices", m_indices) print(m_indices) @@ -442,7 +444,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) + itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -485,8 +488,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype=torch.float32) print(f"NUM_GROUPS = {num_groups}") - p("before w1_s", w1_s) - p("before w2_s", w2_s) assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] @@ -494,8 +495,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - p("imm w1_s", w1_s) - w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() @@ -511,7 +510,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, p("a", a) p("w1", w1) - p("final w1_s", w1_s) + #print(w1) + p("w1_s", w1_s) + #print(w1_s) p("w2", w2) p("w2_s", w2_s) @@ -549,7 +550,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") From 228c054292e91c73ed7ade44b5e4f1f76390fa20 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 2 Mar 2025 20:28:35 +0000 Subject: [PATCH 049/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 81 +++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index cc2d1d8673f0..cdb4b601a1cc 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -295,10 +295,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 - ######################################################################################### - def per_token_cast_to_fp8( x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 @@ -375,16 +373,50 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # ref_out = torch.einsum('gmk,gnk->gmn', x, y) +def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(dtype=torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = torch.empty((a_q[mask].shape[0], w1[i].shape[0]), + device=a_q.device, dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt((a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), + (w1[i], w1_s[i]), + inter_out) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + tmp_out = torch.empty((act_out.shape[0], w2[i].shape[0]), + device=a_q.device, dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt((act_out_q, act_out_s), + (w2[i], w2_s[i]), + tmp_out) + out[mask] = tmp_out + + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" - M = a.numel() // a.shape[-1] - K = w1.shape[-1] num_groups = w1.shape[0] + M = a.numel() // a.shape[-1] # * num_groups) + M_sum = M # * num_groups + K = w1.shape[-1] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - inter_out = torch.empty(a.shape[0], - w1.shape[1], + inter_out = torch.empty((M_sum, K), dtype=torch.bfloat16, device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) @@ -392,8 +424,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + print(f"BLOCK_M {block_m}") + p("A", a) + _, block_k = block_shape[0], block_shape[1] - a_q, a_s = per_token_group_quant_fp8(a, block_k) + a_q, a_s = per_token_group_quant_fp8(a, block_m) + + p("A_q", a_q) + p("A_s", a_s) #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") @@ -437,8 +476,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + return (out.view(M_sum, -1, w2.shape[1]) * + topk_weight.view(M_sum, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( @@ -479,18 +518,22 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w2 = (N + block_k - 1) // block_k - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + # TODO: turn these back to empty calls + w1 = torch.zeros_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.zeros_like(w2_bf16, dtype=torch.float8_e4m3fn) - w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), + w1_s = torch.zeros((num_groups, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), + w2_s = torch.zeros((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) print(f"NUM_GROUPS = {num_groups}") assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + + # TODO: fix later + print("For now, only convert the first group, the rest will be 0") for i in range(num_groups): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -517,7 +560,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, p("w2_s", w2_s) with set_current_vllm_config(vllm_config): - if False: + if True: out = fused_moe( a, w1, @@ -531,9 +574,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) + ref_out = deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) else: + out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + ref_out = fused_moe( a, w1, @@ -547,9 +593,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") From 4439c534fef17935ccb5fbe4b849504770df42e5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 2 Mar 2025 20:40:35 +0000 Subject: [PATCH 050/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index cdb4b601a1cc..d63bbd2e1bb3 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -412,9 +412,11 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] - M = a.numel() // a.shape[-1] # * num_groups) - M_sum = M # * num_groups - K = w1.shape[-1] + M = a.shape[0] + M_sum = M * topk + N = w1.shape[1] // 2 + K = w1.shape[2] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) inter_out = torch.empty((M_sum, K), dtype=torch.bfloat16, @@ -437,10 +439,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - m_indices = torch.arange(0, num_groups, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand( - num_groups, max((topk * M) // num_groups, 1)).contiguous().view(-1) - #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) + # use topk_ids?? + if True: + m_indices = torch.arange(0, num_groups, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand( + num_groups, max(M_sum // num_groups, 1)).contiguous().view(-1) + #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) + else: + pass + p("m_indices", m_indices) print(m_indices) @@ -560,7 +567,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, p("w2_s", w2_s) with set_current_vllm_config(vllm_config): - if True: + if False: out = fused_moe( a, w1, From 487e3196bc5ee8355e09ee203c0a80b3241c23b2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 2 Mar 2025 22:52:51 +0000 Subject: [PATCH 051/171] problem with scores Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 43 ++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index d63bbd2e1bb3..8f63f16f3328 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -12,6 +12,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -414,9 +415,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, num_groups = w1.shape[0] M = a.shape[0] M_sum = M * topk - N = w1.shape[1] // 2 - K = w1.shape[2] - + K = w1.shape[2] # w2.shape[1] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) inter_out = torch.empty((M_sum, K), dtype=torch.bfloat16, @@ -430,28 +429,31 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, print(f"BLOCK_M {block_m}") p("A", a) + row_size = max(M_sum // num_groups, 1) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, row_size, num_groups, None) + ) + m_indices = expert_ids + assert m_indices.numel() == M_sum + print(f"num_tokens_post_padded = {num_tokens_post_padded}") + p("expert ids", expert_ids) + _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) - p("A_q", a_q) - p("A_s", a_s) - #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - # use topk_ids?? - if True: - m_indices = torch.arange(0, num_groups, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand( - num_groups, max(M_sum // num_groups, 1)).contiguous().view(-1) - #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) - else: - pass + # m_indices maps to expert_ids + #m_indices = torch.arange(0, num_groups, dtype=torch.int) + #m_indices = m_indices.unsqueeze(-1).expand( + # num_groups, row_size).contiguous().view(-1) p("m_indices", m_indices) print(m_indices) - print("topk", topk_ids) + print("topk_ids", topk_ids) print(topk_ids) print("topk_weight", topk_weight) print(topk_weight) @@ -483,8 +485,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) - return (out.view(M_sum, -1, w2.shape[1]) * - topk_weight.view(M_sum, -1, 1).to(out.dtype)).sum(dim=1) + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( @@ -516,7 +518,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - score = torch.randn((M, E), dtype=dtype) + #score = torch.randn((M, E), dtype=dtype) + score = torch.zeros((M, E), dtype=dtype) num_groups = E block_n, block_k = block_size[0], block_size[1] @@ -600,8 +603,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / From 8d89dc2df4083da81780b9f6337d0c46118d66c1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 19:39:40 +0000 Subject: [PATCH 052/171] some passing tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 81 +++++++++++-------- .../layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 8f63f16f3328..2b625b838e8a 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -39,13 +39,13 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. #M_moe = [1, 7, 83, 512, 2048] -M_moe = [1, 8, 84, 512, 2048] -N_moe = [4608] # [128, 4608, 13824] -K_moe = [7168] # [256, 7168, 13824] +M_moe = [1, 2, 8, 84, 512] #, 2048] +N_moe = [128, 256, 4608] # [128, 4608, 13824] +K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] #E = [8, 24] # [8, 24, 128, 256] -E = [8, 16] # [8, 24, 128, 256] -TOP_KS = [2] # [1, 2, 6] +E = [2] #, 8] #, 16] # [8, 24, 128, 256] +TOP_KS = [1] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -227,7 +227,11 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): print(f"{s}: {t.shape}, {t.dtype}") + pass +def pp(x): + print(x) + pass @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", @@ -413,11 +417,10 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] - M = a.shape[0] - M_sum = M * topk - K = w1.shape[2] # w2.shape[1] + M, K = a.shape + N = w2.shape[-1] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - inter_out = torch.empty((M_sum, K), + inter_out = torch.empty((a.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) @@ -426,18 +429,18 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_ids = topk_ids.view(-1) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - print(f"BLOCK_M {block_m}") + pp(f"BLOCK_M {block_m}") p("A", a) - row_size = max(M_sum // num_groups, 1) + row_size = max((topk * M) // num_groups, 1) # 2 *? sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, row_size, num_groups, None) + moe_align_block_size(topk_ids, M * topk, num_groups, None) ) m_indices = expert_ids - assert m_indices.numel() == M_sum - print(f"num_tokens_post_padded = {num_tokens_post_padded}") - p("expert ids", expert_ids) + #assert m_indices.numel() == num_groups * M * topk + #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") + #p("expert ids", expert_ids) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) @@ -446,17 +449,16 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, #print(f"FIRST GEMM {a_q.shape}") # m_indices maps to expert_ids - #m_indices = torch.arange(0, num_groups, dtype=torch.int) - #m_indices = m_indices.unsqueeze(-1).expand( - # num_groups, row_size).contiguous().view(-1) - + m_indices = torch.arange(0, M, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) p("m_indices", m_indices) - print(m_indices) + pp(m_indices) + p("topk_ids", topk_ids) + #pp(topk_ids) + p("topk_weight", topk_weight) + #pp(topk_weight) - print("topk_ids", topk_ids) - print(topk_ids) - print("topk_weight", topk_weight) - print(topk_weight) + pp("FIRST GEMM") if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -466,12 +468,14 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) - print(f"DG {inter_out.shape} {inter_out}") + pp("FIRST GEMM DONE") + + #pp(f"DG {inter_out.shape} {inter_out}") act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - #print("SECOND GEMM") + pp("SECOND GEMM") out = torch.empty(act_out.shape[0], w2.shape[1], @@ -485,15 +489,16 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) + pp("SECOND GEMM DONE") + return (out.view(M, -1, w2.shape[1]) * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, - SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -502,6 +507,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (N % 128 != 0 or K % 128 != 0): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + torch.set_printoptions(profile="full") vllm_config = VllmConfig() @@ -519,7 +526,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, fp8_max).clamp(min=fp8_min, max=fp8_max) #score = torch.randn((M, E), dtype=dtype) + if False: + score = torch.empty((M, E), dtype=dtype) + for i in range(M): + score[i] = torch.full((E,), 1.0/(i+1), dtype=dtype) + for i in range(score.numel()): + score.view(-1)[i] = 1.0/(i+1) score = torch.zeros((M, E), dtype=dtype) + p("score", score) + #pp(score) num_groups = E block_n, block_k = block_size[0], block_size[1] @@ -537,13 +552,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = torch.zeros((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - print(f"NUM_GROUPS = {num_groups}") - assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] # TODO: fix later - print("For now, only convert the first group, the rest will be 0") + pp("For now, only convert the first group, the rest will be 0") for i in range(num_groups): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -603,8 +616,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - print(f"{out.sum()=}") - print(f"{ref_out.sum()=}") + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f3e8c59c3612..9a0c56c75ac7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,7 +1360,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + #print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, From abf6171662f98b9aa4a36fc1095f4e1ace2747fd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 19:40:04 +0000 Subject: [PATCH 053/171] some passing tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2b625b838e8a..42709535fea4 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -45,7 +45,7 @@ BLOCK_SIZE = [[128, 128]] #E = [8, 24] # [8, 24, 128, 256] E = [2] #, 8] #, 16] # [8, 24, 128, 256] -TOP_KS = [1] # [1, 2, 6] +TOP_KS = [2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] From c09f42fa3b9c3f825d9ad2d7b1d85f2fb146163a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 19:46:48 +0000 Subject: [PATCH 054/171] topk > 1 doesn't work. prune oom-ing tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 42709535fea4..308956678e18 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -45,7 +45,7 @@ BLOCK_SIZE = [[128, 128]] #E = [8, 24] # [8, 24, 128, 256] E = [2] #, 8] #, 16] # [8, 24, 128, 256] -TOP_KS = [2] # [1, 2, 6] +TOP_KS = [1] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -495,6 +495,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) +# topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) From 2ffac311e51b22e2b1ac3faee1c53746fb1c3e3b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 20:07:51 +0000 Subject: [PATCH 055/171] fix indices Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 308956678e18..1c5f9c2ce645 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -449,8 +449,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, #print(f"FIRST GEMM {a_q.shape}") # m_indices maps to expert_ids - m_indices = torch.arange(0, M, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) + m_indices = torch.arange(0, topk, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) p("m_indices", m_indices) pp(m_indices) p("topk_ids", topk_ids) @@ -499,6 +499,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) #itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, From 4e816055d67544b45c98ffd15b9fba430ed23caf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 20:23:10 +0000 Subject: [PATCH 056/171] enable more tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 1c5f9c2ce645..6831ab139b17 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,14 +38,12 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -#M_moe = [1, 7, 83, 512, 2048] -M_moe = [1, 2, 8, 84, 512] #, 2048] +M_moe = [1, 2, 7, 83, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -#E = [8, 24] # [8, 24, 128, 256] -E = [2] #, 8] #, 16] # [8, 24, 128, 256] -TOP_KS = [1] # [1, 2, 6] +E = [2, 8, 16] # 24 # [8, 24, 128, 256] +TOP_KS = [1, 2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -226,11 +224,11 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - print(f"{s}: {t.shape}, {t.dtype}") + #print(f"{s}: {t.shape}, {t.dtype}") pass def pp(x): - print(x) + #print(x) pass @pytest.mark.parametrize( @@ -505,9 +503,9 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes - if (N % 128 != 0 or K % 128 != 0): - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + # only aligned sizes or supported topk + if (N % 128 != 0 or K % 128 != 0 or topk > 1): + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") From 9f21aa22e0ea4cfea645267c00440f1863e3b4c7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 20:37:22 +0000 Subject: [PATCH 057/171] format Signed-off-by: Bill Nell --- requirements/test.txt | 6 ++++ tests/kernels/test_block_fp8.py | 52 ++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 2e8121e3882e..03093e134524 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -132,6 +132,10 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -772,9 +776,11 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 6831ab139b17..08f620789f7f 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -42,7 +42,7 @@ N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16] # 24 # [8, 24, 128, 256] +E = [2, 8, 16] # 24 # [8, 24, 128, 256] TOP_KS = [1, 2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -227,10 +227,12 @@ def p(s, t): #print(f"{s}: {t.shape}, {t.dtype}") pass + def pp(x): #print(x) pass + @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, @@ -298,8 +300,10 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 + ######################################################################################### + def per_token_cast_to_fp8( x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 @@ -376,11 +380,16 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # ref_out = torch.einsum('gmk,gnk->gmn', x, y) -def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + +def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + out = torch.zeros(B * topk, + w2.shape[1], + dtype=torch.bfloat16, + device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) @@ -393,24 +402,24 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, bloc mask = topk_ids == i if mask.sum(): inter_out = torch.empty((a_q[mask].shape[0], w1[i].shape[0]), - device=a_q.device, dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), - (w1[i], w1_s[i]), - inter_out) + device=a_q.device, + dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt( + (a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), + (w1[i], w1_s[i]), inter_out) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) tmp_out = torch.empty((act_out.shape[0], w2[i].shape[0]), - device=a_q.device, dtype=torch.bfloat16) + device=a_q.device, + dtype=torch.bfloat16) deep_gemm.gemm_fp8_fp8_bf16_nt((act_out_q, act_out_s), - (w2[i], w2_s[i]), - tmp_out) + (w2[i], w2_s[i]), tmp_out) out[mask] = tmp_out return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -433,8 +442,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, row_size = max((topk * M) // num_groups, 1) # 2 *? sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, M * topk, num_groups, None) - ) + moe_align_block_size(topk_ids, M * topk, num_groups, None)) m_indices = expert_ids #assert m_indices.numel() == num_groups * M * topk #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") @@ -496,9 +504,10 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, # topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - #itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) +#itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) +#itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -507,7 +516,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (N % 128 != 0 or K % 128 != 0 or topk > 1): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") - pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}" + ) torch.set_printoptions(profile="full") @@ -529,9 +539,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if False: score = torch.empty((M, E), dtype=dtype) for i in range(M): - score[i] = torch.full((E,), 1.0/(i+1), dtype=dtype) + score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) for i in range(score.numel()): - score.view(-1)[i] = 1.0/(i+1) + score.view(-1)[i] = 1.0 / (i + 1) score = torch.zeros((M, E), dtype=dtype) p("score", score) #pp(score) @@ -597,8 +607,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - ref_out = deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) + ref_out = deep_gemm_matmul_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size) else: out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) From 10ba95d11c5f00c036d801881f89e440b90ace66 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 4 Mar 2025 21:59:00 +0000 Subject: [PATCH 058/171] use fused_topk for unit test Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 159 +++++++++++------- .../layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 103 insertions(+), 58 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 08f620789f7f..05e4de3e3f7b 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -12,7 +12,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size +from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size, fused_topk from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -38,12 +38,16 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 512, 2048] +#M_moe = [1, 2, 7, 83] #, 512, 2048] +M_moe = [128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] +M_moe_small = [128, 512] +N_moe_small = [128, 256] +K_moe_small = [256, 512] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16] # 24 # [8, 24, 128, 256] -TOP_KS = [1, 2] # [1, 2, 6] +E = [2, 8] #, 16] # 24 # [8, 24, 128, 256] +TOP_KS = [1, 2, 6] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -224,7 +228,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - #print(f"{s}: {t.shape}, {t.dtype}") + print(f"{s}: {t.shape}, {t.dtype}") pass @@ -385,13 +389,18 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape + pre_a = a a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) + if False: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + else: + topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) + del pre_a topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) @@ -420,18 +429,25 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, +def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] M, K = a.shape N = w2.shape[-1] + pre_a = a a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + inter_out = torch.empty((a.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) + + if True: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + else: + topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) + del pre_a topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) @@ -439,26 +455,39 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, pp(f"BLOCK_M {block_m}") p("A", a) - row_size = max((topk * M) // num_groups, 1) # 2 *? - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, M * topk, num_groups, None)) - m_indices = expert_ids - #assert m_indices.numel() == num_groups * M * topk - #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") - #p("expert ids", expert_ids) - _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - # m_indices maps to expert_ids - m_indices = torch.arange(0, topk, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) + if False: + m_indices = torch.arange(0, M * topk, dtype=torch.int) + #m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) + m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) + else: + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, 1, M, None)) + #sorted_token_ids, _ = torch.sort(sorted_token_ids, 0, descending=False) + #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(num_groups, M).contiguous().view(-1) + # ??? + #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(M, topk).contiguous().view(-1) + p("SORTED", sorted_token_ids) + pp(sorted_token_ids) + print(sorted_token_ids) + pp(f"mask = {sorted_token_ids == M}") + #sorted_token_ids[sorted_token_ids == 2*M] = -1 + pp(sorted_token_ids) + print(f"max = {torch.max(sorted_token_ids)}, M={M}, topk={topk}") + assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + m_indices = sorted_token_ids + #assert m_indices.numel() == num_groups * M * topk + #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") + #p("expert ids", expert_ids) + p("m_indices", m_indices) - pp(m_indices) + #print(f"m_indices {m_indices.shape} {sorted_token_ids.shape}") + #pp(m_indices) p("topk_ids", topk_ids) #pp(topk_ids) p("topk_weight", topk_weight) @@ -476,11 +505,13 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, pp("FIRST GEMM DONE") - #pp(f"DG {inter_out.shape} {inter_out}") + pp(f"DG {inter_out.shape} {inter_out}") act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + p("act_out", act_out) + pp("SECOND GEMM") out = torch.empty(act_out.shape[0], @@ -501,23 +532,36 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) +def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: + dimensions = [] + + for index, _ in enumerate(shape): + if index != dim: + dimension = 1 + else: + dimension = shape[index] + + dimensions = [*dimensions, dimension] + + return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) + + # topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -#itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) -#itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) + #itertools.product([1], [128], [256], [3], [3], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes or supported topk - if (N % 128 != 0 or K % 128 != 0 or topk > 1): + if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk == 1 or topk > E): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") - pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}" - ) + print(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -535,39 +579,39 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - #score = torch.randn((M, E), dtype=dtype) - if False: - score = torch.empty((M, E), dtype=dtype) - for i in range(M): - score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) - for i in range(score.numel()): - score.view(-1)[i] = 1.0 / (i + 1) - score = torch.zeros((M, E), dtype=dtype) + #score = torch.randn((M, E), dtype=dtype) # does not work + #score = torch.ones((M, E), dtype=dtype) # works + #score = torch.zeros((M, E), dtype=dtype) # works + #score = torch.full((M, E), 0.5, dtype=dtype) # works + #score = torch.empty((M, E), dtype=dtype) + #for i in range(M): # works + # score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) + #score = torch.empty((M, E), dtype=dtype) + #for i in range(score.numel()): # works + # score.view(-1)[i] = 1.0 / (i + 1) + score = iota((M, E), dtype=dtype) p("score", score) #pp(score) - num_groups = E block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n k_tiles_w1 = (K + block_k - 1) // block_k n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w2 = (N + block_k - 1) // block_k - # TODO: turn these back to empty calls - w1 = torch.zeros_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.zeros_like(w2_bf16, dtype=torch.float8_e4m3fn) + # TODO: change these to zeros to test out groups + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - w1_s = torch.zeros((num_groups, n_tiles_w1, k_tiles_w1), - dtype=torch.float32) - w2_s = torch.zeros((num_groups, n_tiles_w2, k_tiles_w2), - dtype=torch.float32) + w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) + assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] # TODO: fix later - pp("For now, only convert the first group, the rest will be 0") - for i in range(num_groups): + #pp("For now, only convert the first group, the rest will be 0") + for i in range(E): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -595,10 +639,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, with set_current_vllm_config(vllm_config): if False: out = fused_moe( - a, + a, #hidden w1, w2, - score, + score, #gating topk, renormalize=False, use_fp8_w8a8=True, @@ -610,14 +654,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = deep_gemm_matmul_w8a8_block_fp8_moe( a, w1, w2, w1_s, w2_s, score, topk, block_size) else: - out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - ref_out = fused_moe( - a, + a, #hidden w1, w2, - score, + score, #gating topk, renormalize=False, use_fp8_w8a8=True, @@ -626,6 +667,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) + out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9a0c56c75ac7..2529582b9b2f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1451,7 +1451,7 @@ def fused_moe( MoE layer. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_top note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. From a46f3d40377806ff5fa93ec9586e4f4572f3a323 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 04:18:32 +0000 Subject: [PATCH 059/171] every other block correct Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 80 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 10 ++- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 05e4de3e3f7b..98eef3475dba 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -228,12 +228,12 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - print(f"{s}: {t.shape}, {t.dtype}") + print(f"{s}: {t.shape}, {t.dtype}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -436,35 +436,49 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, M, K = a.shape N = w2.shape[-1] pre_a = a - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + # to try: turn into 3d view here, do not flatten until after quantization + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + p("A'", a) + print(a) - inter_out = torch.empty((a.shape[0], w1[0].shape[0]), - dtype=torch.bfloat16, - device=a.device) - - if True: - score = torch.softmax(score, dim=-1, dtype=torch.float32) + if False: + scpore = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) + topk_ids, w_sort = topk_ids.sort() + topk_weight = torch.gather(topk_weight, dim=1, index=w_sort) else: topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) - del pre_a - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) + #del pre_a + + # pre_a.shape[0] * topk_ids.shape[1] + inter_out = torch.empty((pre_a.shape[0] * topk, w1[0].shape[0]), + dtype=torch.bfloat16, + device=a.device) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - pp(f"BLOCK_M {block_m}") - p("A", a) + pp(f"M {M}, BLOCK_M {block_m}") + #p("A", a) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) + #a_q, a_s = per_token_cast_to_fp8(a) + + #a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(topk, 1, 1).reshape(-1, a_q.shape[1]) + #a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(topk, 1, 1).reshape(-1, a_s.shape[1]) #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - if False: - m_indices = torch.arange(0, M * topk, dtype=torch.int) - #m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) - m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) + if True: + m_indices = torch.arange(0, topk, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) + #m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) + elif True: + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, 1, num_groups, None) + #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + m_indices = sorted_token_ids + p("SORTED", m_indices) + print(m_indices) else: sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, 1, M, None)) @@ -485,6 +499,9 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") #p("expert ids", expert_ids) + # must happen after align block size + #topk_weight = topk_weight.view(-1) + p("m_indices", m_indices) #print(f"m_indices {m_indices.shape} {sorted_token_ids.shape}") #pp(m_indices) @@ -494,6 +511,12 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #pp(topk_weight) pp("FIRST GEMM") + pp(f"E = {num_groups}") + p("A", a_q) + p("A_s", a_s) + p("B", w1) + p("B_s", w1_s) + p("m_indices", m_indices) if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -503,22 +526,28 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) + p("out", inter_out) pp("FIRST GEMM DONE") - pp(f"DG {inter_out.shape} {inter_out}") + #pp(f"DG {inter_out.shape} {inter_out}") act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - p("act_out", act_out) - - pp("SECOND GEMM") - out = torch.empty(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, device=a.device) + pp("SECOND GEMM") + pp(f"E = {num_groups}") + p("A", act_out) + p("A_s", act_out_s) + p("B", w2) + p("B_s", w2_s) + p("topk_weights", topk_weight) + p("m_indices", m_indices) + if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) @@ -526,6 +555,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) + p("out", out) pp("SECOND GEMM DONE") return (out.view(M, -1, w2.shape[1]) * @@ -550,9 +580,9 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - #itertools.product([1], [128], [256], [3], [3], [[128, 128]], DTYPES, SEEDS)) + itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2529582b9b2f..89c8e1e5aacf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -634,6 +634,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, **config, ) + p("fused_out", C) + print(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") + # Adapted from: https://github.com/sgl-project/sglang/pull/2628 def get_config_file_name(E: int, @@ -1304,6 +1307,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) + print(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") + print(f"FUSED A {hidden_states.shape}, {hidden_states}") + for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1360,7 +1366,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - #print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1483,6 +1489,8 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ + print(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") + if use_grouped_topk: assert num_expert_group is not None and topk_group is not None topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, From 4cf77705834d764f4b8ebf6a279d72420e03d074 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 21:14:46 +0000 Subject: [PATCH 060/171] working Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 50 ++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 98eef3475dba..3af652dae6a7 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -229,6 +229,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): print(f"{s}: {t.shape}, {t.dtype}\n{t}") + #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") pass @@ -429,6 +430,12 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +# repeat_interleaved. +# shuffle input by token ids +# unshuffle output by argsorted token ids +# argsort token ids + + def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -437,9 +444,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, N = w2.shape[-1] pre_a = a # to try: turn into 3d view here, do not flatten until after quantization - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + #a = a.view(M, -1, K).repeat_interleave(topk, dim=0).reshape(-1, K) # orig p("A'", a) - print(a) + #print(a) if False: scpore = torch.softmax(score, dim=-1, dtype=torch.float32) @@ -460,25 +468,26 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #p("A", a) _, block_k = block_shape[0], block_shape[1] - a_q, a_s = per_token_group_quant_fp8(a, block_m) - #a_q, a_s = per_token_cast_to_fp8(a) + #a_q, a_s = per_token_group_quant_fp8(a, block_m) + #a_q, a_s = per_token_cast_to_fp8(a) #a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(topk, 1, 1).reshape(-1, a_q.shape[1]) #a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(topk, 1, 1).reshape(-1, a_s.shape[1]) + #p("A_q", a_q) + #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - if True: + if False: m_indices = torch.arange(0, topk, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) #m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) elif True: - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, 1, num_groups, None) + sorted_token_ids, expert_ids, _ = moe_align_block_size(topk_ids, 1, num_groups, None) #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 m_indices = sorted_token_ids - p("SORTED", m_indices) - print(m_indices) + p("SORTED", sorted_token_ids) else: sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, 1, M, None)) @@ -499,6 +508,25 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") #p("expert ids", expert_ids) + #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + + a_q, a_s = per_token_group_quant_fp8(a, block_m) + p("a_s_0", a_s) + + a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) # orig + a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # orig + + print(f"max = {topk*M}") + # gather? + a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) + a_s = a_s[sorted_token_ids] + #a_s = torch.gather(a_s, dim=0, index=sorted_token_ids.clamp((topk*M)-1).view(-1, 1).to(dtype=torch.int64)) + + m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) + + p("a_q_s", a_q) + p("a_s_s", a_s) + # must happen after align block size #topk_weight = topk_weight.view(-1) @@ -526,7 +554,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) - p("out", inter_out) + p("inter_out", inter_out) pp("FIRST GEMM DONE") #pp(f"DG {inter_out.shape} {inter_out}") @@ -558,7 +586,9 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("out", out) pp("SECOND GEMM DONE") - return (out.view(M, -1, w2.shape[1]) * + inv_perm = torch.argsort(sorted_token_ids) + + return (out[inv_perm].view(M, -1, w2.shape[1]) * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) From 65a3ef35a8c79561bf498ce8e6be1202dd24b4db Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 21:46:41 +0000 Subject: [PATCH 061/171] enable more tests Signed-off-by: Bill Nell --- requirements/test.txt | 6 ----- tests/kernels/test_block_fp8.py | 18 ++++++------- .../layers/fused_moe/fused_moe.py | 26 +++++++++++++------ 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 03093e134524..2e8121e3882e 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -132,10 +132,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -776,11 +772,9 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3af652dae6a7..df63fd520734 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -42,9 +42,9 @@ M_moe = [128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] -M_moe_small = [128, 512] -N_moe_small = [128, 256] -K_moe_small = [256, 512] +M_moe_small = [128, 512, 2048] +N_moe_small = [128, 256, 4608] +K_moe_small = [256, 512, 7168] BLOCK_SIZE = [[128, 128]] E = [2, 8] #, 16] # 24 # [8, 24, 128, 256] TOP_KS = [1, 2, 6] # [1, 2, 6] @@ -228,13 +228,13 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - print(f"{s}: {t.shape}, {t.dtype}\n{t}") + #print(f"{s}: {t.shape}, {t.dtype}\n{t}") #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") pass def pp(x): - print(x) + #print(x) pass @@ -516,7 +516,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) # orig a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # orig - print(f"max = {topk*M}") + pp(f"max = {topk*M}") # gather? a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) a_s = a_s[sorted_token_ids] @@ -610,9 +610,9 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -621,7 +621,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk == 1 or topk > E): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") - print(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 89c8e1e5aacf..6fbb7295f3a8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -29,6 +29,17 @@ logger = init_logger(__name__) +def p(s, t): + #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") + #print(f"{s}: {t.shape}, {t.dtype}\n{t}") + pass + + +def pp(x): + #print(x) + pass + + @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, @@ -503,7 +514,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k - # EM = num_groups EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. @@ -635,7 +645,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ) p("fused_out", C) - print(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") + pp(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") # Adapted from: https://github.com/sgl-project/sglang/pull/2628 @@ -1232,7 +1242,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None): + block_shape: Optional[List[int]] = None) -> torch.Tensor: # Check constraints. if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ @@ -1307,8 +1317,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - print(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") - print(f"FUSED A {hidden_states.shape}, {hidden_states}") + pp(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") + pp(f"FUSED A {hidden_states.shape}, {hidden_states}") for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, @@ -1366,7 +1376,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1457,7 +1467,7 @@ def fused_moe( MoE layer. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_top + - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. @@ -1489,7 +1499,7 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ - print(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") + pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") if use_grouped_topk: assert num_expert_group is not None and topk_group is not None From 65ce6e74790a02c390aa6091f3965e2a5fb014e7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 22:04:44 +0000 Subject: [PATCH 062/171] working tests w/permute Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 8 ++++---- vllm/model_executor/layers/fused_moe/fused_moe.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index df63fd520734..26b455ad1469 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -609,8 +609,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: # topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() @@ -618,8 +618,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes or supported topk - if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk == 1 or topk > E): - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") + if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0): + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6fbb7295f3a8..c2a8b5409756 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1318,7 +1318,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, out_hidden_states = torch.empty_like(hidden_states) pp(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") - pp(f"FUSED A {hidden_states.shape}, {hidden_states}") + #pp(f"FUSED A {hidden_states.shape}, {hidden_states}") for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, @@ -1376,7 +1376,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + #pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1499,7 +1499,7 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ - pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") + #pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") if use_grouped_topk: assert num_expert_group is not None and topk_group is not None From 75b376c35391a2d0cf4f705085cb4fb4e5134e29 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 22:18:10 +0000 Subject: [PATCH 063/171] cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 257 +++--------------- .../layers/fused_moe/fused_moe.py | 21 -- 2 files changed, 43 insertions(+), 235 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 26b455ad1469..a10c7cc905ce 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -12,7 +12,8 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size, fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -26,28 +27,17 @@ NUM_TOKENS = [7, 83, 2048] D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] -#M = [1, 7, 83, 512, 2048] - -M = [1, 8, 84, 512, 2048, 4096] +M = [1, 7, 8, 83, 84, 512, 2048, 4096] N = [128, 512, 1024, 4096, 7748, 13824, 7168] K = [256, 4096, 5120, 3884, 13824, 16384] - -#M = [128] -#N = [24576] -#K = [1536] - # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -#M_moe = [1, 2, 7, 83] #, 512, 2048] -M_moe = [128, 512, 2048] +M_moe = [1, 2, 7, 83, 128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] -M_moe_small = [128, 512, 2048] -N_moe_small = [128, 256, 4608] -K_moe_small = [256, 512, 7168] BLOCK_SIZE = [[128, 128]] -E = [2, 8] #, 16] # 24 # [8, 24, 128, 256] -TOP_KS = [1, 2, 6] # [1, 2, 6] +E = [2, 8, 16, 24] +TOP_KS = [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -227,17 +217,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -def p(s, t): - #print(f"{s}: {t.shape}, {t.dtype}\n{t}") - #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") - pass - - -def pp(x): - #print(x) - pass - - @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, @@ -275,12 +254,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) - p("a", a) - p("w1", w1) - p("w1_s", w1_s) - p("w2", w2) - p("w2_s", w2_s) - with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -306,19 +279,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): assert rel_diff < 0.03 -######################################################################################### - - -def per_token_cast_to_fp8( - x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to( - torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) - - def per_block_cast_to_fp8( x: torch.Tensor, block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: @@ -381,29 +341,19 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -################################################################################### - -# ref_out = torch.einsum('gmk,gnk->gmn', x, y) - - def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" + """Fused moe with block-wise quantization using DeepGemm.""" + topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + B, D = a.shape - pre_a = a a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - if False: - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - else: - topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) - del pre_a - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_k) @@ -430,134 +380,45 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -# repeat_interleaved. -# shuffle input by token ids -# unshuffle output by argsorted token ids -# argsort token ids - - def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] M, K = a.shape N = w2.shape[-1] - pre_a = a - # to try: turn into 3d view here, do not flatten until after quantization - #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig - #a = a.view(M, -1, K).repeat_interleave(topk, dim=0).reshape(-1, K) # orig - p("A'", a) - #print(a) - - if False: - scpore = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_ids, w_sort = topk_ids.sort() - topk_weight = torch.gather(topk_weight, dim=1, index=w_sort) - else: - topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) - #del pre_a - - # pre_a.shape[0] * topk_ids.shape[1] - inter_out = torch.empty((pre_a.shape[0] * topk, w1[0].shape[0]), + + topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + + inter_out = torch.empty((M * topk, w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - pp(f"M {M}, BLOCK_M {block_m}") - #p("A", a) _, block_k = block_shape[0], block_shape[1] - #a_q, a_s = per_token_group_quant_fp8(a, block_m) - #a_q, a_s = per_token_cast_to_fp8(a) - #a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(topk, 1, 1).reshape(-1, a_q.shape[1]) - #a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(topk, 1, 1).reshape(-1, a_s.shape[1]) - - #p("A_q", a_q) - - #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) - #print(f"FIRST GEMM {a_q.shape}") - - if False: - m_indices = torch.arange(0, topk, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) - #m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) - elif True: - sorted_token_ids, expert_ids, _ = moe_align_block_size(topk_ids, 1, num_groups, None) - #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - m_indices = sorted_token_ids - p("SORTED", sorted_token_ids) - else: - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, 1, M, None)) - #sorted_token_ids, _ = torch.sort(sorted_token_ids, 0, descending=False) - #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(num_groups, M).contiguous().view(-1) - # ??? - #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(M, topk).contiguous().view(-1) - p("SORTED", sorted_token_ids) - pp(sorted_token_ids) - print(sorted_token_ids) - pp(f"mask = {sorted_token_ids == M}") - #sorted_token_ids[sorted_token_ids == 2*M] = -1 - pp(sorted_token_ids) - print(f"max = {torch.max(sorted_token_ids)}, M={M}, topk={topk}") - assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - m_indices = sorted_token_ids - #assert m_indices.numel() == num_groups * M * topk - #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") - #p("expert ids", expert_ids) - - #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + sorted_token_ids, expert_ids, _ = moe_align_block_size( + topk_ids, 1, num_groups, None) + #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + m_indices = sorted_token_ids a_q, a_s = per_token_group_quant_fp8(a, block_m) - p("a_s_0", a_s) - a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) # orig - a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # orig + a_q = a_q.view(a_q.shape[0], -1, + a_q.shape[1]).repeat(1, topk, + 1).reshape(-1, a_q.shape[1]) # orig + a_s = a_s.view(a_s.shape[0], -1, + a_s.shape[1]).repeat(1, topk, + 1).reshape(-1, a_s.shape[1]) # orig - pp(f"max = {topk*M}") - # gather? - a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) + a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, + ...].view(dtype=torch.float8_e4m3fn) a_s = a_s[sorted_token_ids] - #a_s = torch.gather(a_s, dim=0, index=sorted_token_ids.clamp((topk*M)-1).view(-1, 1).to(dtype=torch.int64)) - - m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) - p("a_q_s", a_q) - p("a_s_s", a_s) + m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) - # must happen after align block size - #topk_weight = topk_weight.view(-1) - - p("m_indices", m_indices) - #print(f"m_indices {m_indices.shape} {sorted_token_ids.shape}") - #pp(m_indices) - p("topk_ids", topk_ids) - #pp(topk_ids) - p("topk_weight", topk_weight) - #pp(topk_weight) - - pp("FIRST GEMM") - pp(f"E = {num_groups}") - p("A", a_q) - p("A_s", a_s) - p("B", w1) - p("B_s", w1_s) - p("m_indices", m_indices) - - if True: - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a_q, a_s), (w1, w1_s), inter_out, m_indices) - else: - topk_ids = topk_ids.to(dtype=torch.int32) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), - inter_out, topk_ids, M) - - p("inter_out", inter_out) - pp("FIRST GEMM DONE") - - #pp(f"DG {inter_out.shape} {inter_out}") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), + inter_out, m_indices) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) @@ -567,24 +428,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) - pp("SECOND GEMM") - pp(f"E = {num_groups}") - p("A", act_out) - p("A_s", act_out_s) - p("B", w2) - p("B_s", w2_s) - p("topk_weights", topk_weight) - p("m_indices", m_indices) - - if True: - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - else: - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( - (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) - - p("out", out) - pp("SECOND GEMM DONE") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (act_out_q, act_out_s), (w2, w2_s), out, m_indices) inv_perm = torch.argsort(sorted_token_ids) @@ -606,13 +451,10 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) -# topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -621,7 +463,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") - pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + #pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -639,6 +481,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) + # TODO!!!!!!!!!!!! #score = torch.randn((M, E), dtype=dtype) # does not work #score = torch.ones((M, E), dtype=dtype) # works #score = torch.zeros((M, E), dtype=dtype) # works @@ -650,7 +493,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, #for i in range(score.numel()): # works # score.view(-1)[i] = 1.0 / (i + 1) score = iota((M, E), dtype=dtype) - p("score", score) + #p("score", score) #pp(score) block_n, block_k = block_size[0], block_size[1] @@ -659,7 +502,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w2 = (N + block_k - 1) // block_k - # TODO: change these to zeros to test out groups w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) @@ -669,8 +511,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - # TODO: fix later - #pp("For now, only convert the first group, the rest will be 0") for i in range(E): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -680,29 +520,19 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # TODO: move size alignment further up when setting up all shapes if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: - p("w1_sa", w1_sa) - p("w2_sa", w2_sa) print("UNALIGNED") pytest.skip("UNALIGNED") w1_s = w1_sa w2_s = w2_sa - p("a", a) - p("w1", w1) - #print(w1) - p("w1_s", w1_s) - #print(w1_s) - p("w2", w2) - p("w2_s", w2_s) - with set_current_vllm_config(vllm_config): if False: out = fused_moe( - a, #hidden + a, w1, w2, - score, #gating + score, topk, renormalize=False, use_fp8_w8a8=True, @@ -715,10 +545,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, a, w1, w2, w1_s, w2_s, score, topk, block_size) else: ref_out = fused_moe( - a, #hidden + a, w1, w2, - score, #gating + score, topk, renormalize=False, use_fp8_w8a8=True, @@ -727,9 +557,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, - topk, block_size) - + out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c2a8b5409756..47ab937f7ae8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -29,17 +29,6 @@ logger = init_logger(__name__) -def p(s, t): - #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") - #print(f"{s}: {t.shape}, {t.dtype}\n{t}") - pass - - -def pp(x): - #print(x) - pass - - @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, @@ -644,9 +633,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, **config, ) - p("fused_out", C) - pp(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") - # Adapted from: https://github.com/sgl-project/sglang/pull/2628 def get_config_file_name(E: int, @@ -1317,9 +1303,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - pp(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") - #pp(f"FUSED A {hidden_states.shape}, {hidden_states}") - for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1376,8 +1359,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - #pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") - if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1499,8 +1480,6 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ - #pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") - if use_grouped_topk: assert num_expert_group is not None and topk_group is not None topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, From 416dec44d5935d0972b5ec249297e67da1416150 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 20:56:30 +0000 Subject: [PATCH 064/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 130 ++++++++---------- .../layers/fused_moe/fused_moe.py | 48 ++++++- 2 files changed, 102 insertions(+), 76 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index a10c7cc905ce..2d3bc98d4909 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -42,6 +42,15 @@ SEEDS = [0] +def p(s, t): + #print(f"{s}: {t.shape}\n{t}") + pass + +def pp(x): + #print(x) + pass + + def native_per_token_group_quant_fp8(x, group_size, eps=1e-10, @@ -341,45 +350,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm.""" - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, - w2.shape[1], - dtype=torch.bfloat16, - device=a.device) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(dtype=torch.float32) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = torch.empty((a_q[mask].shape[0], w1[i].shape[0]), - device=a_q.device, - dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt( - (a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), - (w1[i], w1_s[i]), inter_out) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - tmp_out = torch.empty((act_out.shape[0], w2[i].shape[0]), - device=a_q.device, - dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((act_out_q, act_out_s), - (w2[i], w2_s[i]), tmp_out) - out[mask] = tmp_out - - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -397,32 +367,53 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - sorted_token_ids, expert_ids, _ = moe_align_block_size( - topk_ids, 1, num_groups, None) + sorted_token_ids, m_indices, num_pad = moe_align_block_size( + topk_ids, 1, num_groups, None) # topk? #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - m_indices = sorted_token_ids + + pp(f"num_pad = {num_pad}") + p("orig sorted", sorted_token_ids) + + oob_idx = (sorted_token_ids == M*topk).nonzero() + p("oob_idx", oob_idx) + + sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)[:M*num_groups] + inv_perm = torch.argsort(sorted_token_ids) + + p("m_indices", m_indices) + assert m_indices.numel() == M * topk a_q, a_s = per_token_group_quant_fp8(a, block_m) + # Replicate activations and scales a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, - 1).reshape(-1, a_q.shape[1]) # orig + 1).reshape(-1, a_q.shape[1]) a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, - 1).reshape(-1, a_s.shape[1]) # orig + 1).reshape(-1, a_s.shape[1]) + # Permute activations according to sorted token ids a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) a_s = a_s[sorted_token_ids] - m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) + p("topk_ids", topk_ids) + p("sorted", sorted_token_ids) + p("m_indices", m_indices) + p("topk_weight", topk_weight) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) + #inter_out = inter_out[inv_perm, ...] + act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) +# act_out_q = act_out_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) +# act_out_s = act_out_s[sorted_token_ids] + out = torch.empty(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, @@ -431,11 +422,22 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - inv_perm = torch.argsort(sorted_token_ids) + out = out[inv_perm,...] + #topk_weight = topk_weight[inv_perm] + #out[:,num_pad:] = 0 + + #p("inter_out", inter_out) + p("out", out) - return (out[inv_perm].view(M, -1, w2.shape[1]) * + final_out = (out.view(M, -1, w2.shape[1]) * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + p("final_out", final_out) + + # TODO use moe_sum + + return final_out + def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: dimensions = [] @@ -453,17 +455,17 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes or supported topk - if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0): - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + # only aligned sizes + if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk > E): + pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") - #pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -481,20 +483,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - # TODO!!!!!!!!!!!! - #score = torch.randn((M, E), dtype=dtype) # does not work - #score = torch.ones((M, E), dtype=dtype) # works - #score = torch.zeros((M, E), dtype=dtype) # works - #score = torch.full((M, E), 0.5, dtype=dtype) # works - #score = torch.empty((M, E), dtype=dtype) - #for i in range(M): # works - # score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) - #score = torch.empty((M, E), dtype=dtype) - #for i in range(score.numel()): # works - # score.view(-1)[i] = 1.0 / (i + 1) + score = torch.randn((M, E), dtype=dtype) # does not work score = iota((M, E), dtype=dtype) - #p("score", score) - #pp(score) block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n @@ -541,9 +531,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - ref_out = deep_gemm_matmul_w8a8_block_fp8_moe( + ref_out = torch_w8a8_block_fp8_moe( a, w1, w2, w1_s, w2_s, score, topk, block_size) else: + out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + ref_out = fused_moe( a, w1, @@ -557,9 +550,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 47ab937f7ae8..0fa18aa6f08b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -28,6 +28,23 @@ logger = init_logger(__name__) +use_deep_gemm = False +if True or envs.VLLM_USE_DEEP_GEMM: + try: + import deep_gemm as dg + use_deep_gemm = True + except ImportError: + logger.warning("Failed to import DeepGemm kernels.") + + +def p(s, t): + #print(f"{s}: {t.shape}\n{t}") + pass + +def pp(x): + #print(x) + pass + @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, @@ -511,6 +528,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, # and we can skip some invalid blocks. EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config['BLOCK_SIZE_M']) + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) @@ -764,7 +782,7 @@ def get_default_config( # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 64 if not use_deep_gemm else dg.get_m_alignment_for_contiguous_layout(), "BLOCK_SIZE_N": block_shape[0], "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, @@ -799,10 +817,11 @@ def get_default_config( "GROUP_SIZE_M": 1, } else: + dg_config = use_deep_gemm and dtype == "fp8_w8a8" config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64 if not dg_config else dg.get_m_alignment_for_contiguous_layout(), + "BLOCK_SIZE_N": 64 if not dg_config else 128, + "BLOCK_SIZE_K": 32 if not dg_config else 128, "GROUP_SIZE_M": 8, } return config @@ -1303,7 +1322,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - for chunk in range((num_tokens // CHUNK_SIZE) + 1): + use_dg = False and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + + if use_dg: + print("USE_DG!!!!!!!!!!!!!") + num_chunks = 1 + assert w1_scale is not None + assert w2_scale is not None + # TODO: do this offline + w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() + w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() + else: + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) @@ -1335,7 +1367,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'] if not use_dg else 1, global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, @@ -1396,6 +1428,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + p("fused topk", topk_ids) + p("fused sorted", sorted_token_ids) + p("fused topk_weight", topk_weights) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) From 55d9eface1f91fd6d8d62d59b1ec271938a031a5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 21:13:59 +0000 Subject: [PATCH 065/171] not crashing Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 ++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2d3bc98d4909..8a9eb674c153 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -455,8 +455,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0fa18aa6f08b..784f567fa713 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -38,11 +38,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -528,7 +528,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, # and we can skip some invalid blocks. EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config['BLOCK_SIZE_M']) - grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) @@ -1322,7 +1321,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = False and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + use_dg = valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) if use_dg: print("USE_DG!!!!!!!!!!!!!") From ae402f53cf10e7d37c9c68cbfafc358351f8106c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 22:37:58 +0000 Subject: [PATCH 066/171] baseline working integration Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 7 ++++--- vllm/model_executor/layers/fused_moe/fused_moe.py | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 8a9eb674c153..f6de12d65642 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -33,6 +33,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] +M_moe_dg = [128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] @@ -369,7 +370,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, 1, num_groups, None) # topk? - #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 pp(f"num_pad = {num_pad}") p("orig sorted", sorted_token_ids) @@ -455,8 +456,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 784f567fa713..5748a168540a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -38,11 +38,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass @@ -1324,13 +1324,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_dg = valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) if use_dg: - print("USE_DG!!!!!!!!!!!!!") + #print("USE_DG!!!!!!!!!!!!!") num_chunks = 1 + CHUNK_SIZE = num_tokens assert w1_scale is not None assert w2_scale is not None # TODO: do this offline + #print("GOT HERE A") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() + #print("GOT HERE B") else: num_chunks = (num_tokens // CHUNK_SIZE) + 1 From 6587ea189ec2b4c46f89e2f831148350a6308dff Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 22:52:35 +0000 Subject: [PATCH 067/171] add allow_deep_gemm flag Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 1 + .../layers/fused_moe/fused_moe.py | 26 +++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index f6de12d65642..5e968d784c52 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -549,6 +549,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, + allow_deep_gemm=True ) #print(f"{out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5748a168540a..c48dde0bf3c6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1029,13 +1029,14 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + block_shape, allow_deep_gemm) def inplace_fused_experts_fake( @@ -1059,7 +1060,8 @@ def inplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> None: pass @@ -1093,7 +1095,8 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1123,7 +1126,8 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1321,12 +1325,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) if use_dg: #print("USE_DG!!!!!!!!!!!!!") - num_chunks = 1 - CHUNK_SIZE = num_tokens + # TODO: how to test chunks? + #num_chunks = 1 + #CHUNK_SIZE = num_tokens + num_chunks = (num_tokens // CHUNK_SIZE) + 1 assert w1_scale is not None assert w2_scale is not None # TODO: do this offline @@ -1337,6 +1343,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: num_chunks = (num_tokens // CHUNK_SIZE) + 1 + if num_chunks > 1: + print("CHUNKS!!!!!!!!!!!!!!!!!!") + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1467,6 +1476,7 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of From cc7ec3fec0ab28cf3ccbf3446e8a7c63f470d872 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Mar 2025 21:48:28 +0000 Subject: [PATCH 068/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 118 +++++++++++++----- .../layers/fused_moe/fused_moe.py | 14 ++- 2 files changed, 97 insertions(+), 35 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 5e968d784c52..708cf61352d4 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -168,6 +168,48 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch2_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) + + assert topk_ids.numel() == a_q.shape[0] == B * topk + + for i in range(w1.shape[0]): + mask = topk_ids == i + print(f"sum = {mask.numel()}, {mask.nonzero()}") + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + # Skip all tests if CUDA is not available pytest.importorskip("torch.cuda") @@ -360,39 +402,49 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - inter_out = torch.empty((M * topk, w1[0].shape[0]), - dtype=torch.bfloat16, - device=a.device) - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() _, block_k = block_shape[0], block_shape[1] + #sorted_token_ids, m_indices, num_pad = moe_align_block_size( + # topk_ids, 1, num_groups, None) + sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, 1, num_groups, None) # topk? - assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + topk_ids, M, num_groups, None) + + pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") + + #sorted_token_ids = sorted_token_ids[:num_pad] + pad_size = (m_indices.numel() * M) - sorted_token_ids.numel() + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", topk*M) + p("sorted_token_ids2", sorted_token_ids) + p("orig m_indices", m_indices) + m_indices = torch.repeat_interleave(m_indices, M, dim=0) #[:num_pad] + + # M * topk + #assert topk_ids.numel() == sorted_token_ids.numel() == num_pad - pp(f"num_pad = {num_pad}") - p("orig sorted", sorted_token_ids) + mask = sorted_token_ids == topk*M # zero out a_q[mask]? + + sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)#[:num_pad] + + assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - oob_idx = (sorted_token_ids == M*topk).nonzero() - p("oob_idx", oob_idx) - sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)[:M*num_groups] inv_perm = torch.argsort(sorted_token_ids) p("m_indices", m_indices) - assert m_indices.numel() == M * topk + #assert m_indices.numel() == M * topk a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales - a_q = a_q.view(a_q.shape[0], -1, - a_q.shape[1]).repeat(1, topk, - 1).reshape(-1, a_q.shape[1]) - a_s = a_s.view(a_s.shape[0], -1, - a_s.shape[1]).repeat(1, topk, - 1).reshape(-1, a_s.shape[1]) +# a_q = a_q.view(a_q.shape[0], -1, +# a_q.shape[1]).repeat(1, topk, +# 1).reshape(-1, a_q.shape[1]) +# a_s = a_s.view(a_s.shape[0], -1, +# a_s.shape[1]).repeat(1, topk, +# 1).reshape(-1, a_s.shape[1]) # Permute activations according to sorted token ids a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, @@ -401,9 +453,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("topk_ids", topk_ids) p("sorted", sorted_token_ids) - p("m_indices", m_indices) p("topk_weight", topk_weight) + p("a_q", a_q) + p("a_s", a_s) + p("m_indices", m_indices) + + inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), + dtype=torch.bfloat16, + device=a.device) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) @@ -415,7 +474,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, # act_out_q = act_out_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) # act_out_s = act_out_s[sorted_token_ids] - out = torch.empty(act_out.shape[0], + out = torch.zeros(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, device=a.device) @@ -427,11 +486,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #topk_weight = topk_weight[inv_perm] #out[:,num_pad:] = 0 - #p("inter_out", inter_out) - p("out", out) + p("inter_out", inter_out) + #p("out", out) final_out = (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) p("final_out", final_out) @@ -456,8 +515,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -485,7 +544,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, fp8_max).clamp(min=fp8_min, max=fp8_max) score = torch.randn((M, E), dtype=dtype) # does not work - score = iota((M, E), dtype=dtype) + #score = iota((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n @@ -530,6 +589,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, + allow_deep_gemm=False ) ref_out = torch_w8a8_block_fp8_moe( @@ -549,7 +609,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, - allow_deep_gemm=True + allow_deep_gemm=False ) #print(f"{out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c48dde0bf3c6..483b869e67aa 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -38,11 +38,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -531,6 +531,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) + p("fused a_q", A) + p("fused a_s", A_scale) + p("fused expert ids", expert_ids) + if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 @@ -1402,6 +1406,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + p("fused inter_out", intermediate_cache1) + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1439,10 +1445,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - p("fused topk", topk_ids) - p("fused sorted", sorted_token_ids) - p("fused topk_weight", topk_weights) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) From da0fd3eba390f0f141e1046df3df62b79ad13a2a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Mar 2025 22:02:44 +0000 Subject: [PATCH 069/171] better Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 23 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 4 ++-- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 708cf61352d4..3b7b34aa91b4 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass @@ -393,6 +393,11 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +# dtype=torch.float8_e4m3fn +def fp8_perm(m, idx): + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + + def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -447,15 +452,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, # 1).reshape(-1, a_s.shape[1]) # Permute activations according to sorted token ids - a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, - ...].view(dtype=torch.float8_e4m3fn) + a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] + #a_q.view(dtype=torch.uint8)[mask] = 0 + p("topk_ids", topk_ids) p("sorted", sorted_token_ids) p("topk_weight", topk_weight) - p("a_q", a_q) + p("a_q", fp8_perm(a_q, inv_perm)) p("a_s", a_s) p("m_indices", m_indices) @@ -489,8 +495,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("inter_out", inter_out) #p("out", out) - final_out = (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) + #final_out = (out.view(M, -1, w2.shape[1]) * + # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) + + final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) p("final_out", final_out) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 483b869e67aa..36a0e27c6956 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -38,11 +38,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass From 6b08ac7400aa6a8245429cbb5a2a6ce04c6a29e6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 01:15:42 +0000 Subject: [PATCH 070/171] fix some stuff Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 95 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3b7b34aa91b4..6bb1b24b1202 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -411,57 +411,69 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - #sorted_token_ids, m_indices, num_pad = moe_align_block_size( - # topk_ids, 1, num_groups, None) sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, M, num_groups, None) + topk_ids, block_m, num_groups, None) pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") #sorted_token_ids = sorted_token_ids[:num_pad] - pad_size = (m_indices.numel() * M) - sorted_token_ids.numel() - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", topk*M) - p("sorted_token_ids2", sorted_token_ids) - p("orig m_indices", m_indices) - m_indices = torch.repeat_interleave(m_indices, M, dim=0) #[:num_pad] - # M * topk - #assert topk_ids.numel() == sorted_token_ids.numel() == num_pad + print("GOT HERE1") + + num_tokens = topk * M + + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + if pad_size > 0: + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) - mask = sorted_token_ids == topk*M # zero out a_q[mask]? + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) + + #m_indices = m_indices[(sorted_token_ids.numel() // 128):] + + p("sorted_token_ids", sorted_token_ids) + p("sorted_token_ids[:num_pad]", sorted_token_ids[:num_pad]) + #sorted_token_ids = sorted_token_ids[:num_pad] + p("orig m_indices", m_indices) + m_indices = torch.repeat_interleave(m_indices, M, dim=0) - sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)#[:num_pad] + print("GOT HERE2") - assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 + print("GOT HERE2A") inv_perm = torch.argsort(sorted_token_ids) p("m_indices", m_indices) - #assert m_indices.numel() == M * topk + + print("GOT HERE2B") a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales -# a_q = a_q.view(a_q.shape[0], -1, -# a_q.shape[1]).repeat(1, topk, -# 1).reshape(-1, a_q.shape[1]) -# a_s = a_s.view(a_s.shape[0], -1, -# a_s.shape[1]).repeat(1, topk, -# 1).reshape(-1, a_s.shape[1]) + a_q = a_q.view(a_q.shape[0], -1, + a_q.shape[1]).repeat(1, topk, + 1).reshape(-1, a_q.shape[1]) + a_s = a_s.view(a_s.shape[0], -1, + a_s.shape[1]).repeat(1, topk, + 1).reshape(-1, a_s.shape[1]) + + print("GOT HERE2C") # Permute activations according to sorted token ids a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] + print("GOT HERE3") + #a_q.view(dtype=torch.uint8)[mask] = 0 p("topk_ids", topk_ids) p("sorted", sorted_token_ids) p("topk_weight", topk_weight) - p("a_q", fp8_perm(a_q, inv_perm)) + p("a_q", a_q) p("a_s", a_s) p("m_indices", m_indices) @@ -469,9 +481,15 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) + + print("GOT HERE4") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) + + print("GOT HERE5") + #inter_out = inter_out[inv_perm, ...] act_out = SiluAndMul().forward_native(inter_out) @@ -485,10 +503,17 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) + print("GOT HERE6") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) + print("GOT HERE7") + out = out[inv_perm,...] + + print("GOT HERE8") + #topk_weight = topk_weight[inv_perm] #out[:,num_pad:] = 0 @@ -498,8 +523,20 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #final_out = (out.view(M, -1, w2.shape[1]) * # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) - final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + print(f"GOT HERE9 {out.shape}, {M}, {num_tokens}") + + TT = topk_weight.shape[0] + tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] + #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) + + print(f"GOT HERE10 {tmp_out.shape}, {topk_weight.shape}") + + final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + #final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * + # topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + print("GOT HERE11") p("final_out", final_out) @@ -521,11 +558,14 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) - +# topk 6 broken/slow @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -535,6 +575,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + print(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 36a0e27c6956..fb0440b560e1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1382,7 +1382,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'] if not use_dg else 1, + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, From caa58c03ce0efddb5af3934f154c2dea96ce1d76 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 01:20:07 +0000 Subject: [PATCH 071/171] fix more stuff Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 6bb1b24b1202..ea1b127f3fb6 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass @@ -435,7 +435,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("sorted_token_ids[:num_pad]", sorted_token_ids[:num_pad]) #sorted_token_ids = sorted_token_ids[:num_pad] p("orig m_indices", m_indices) - m_indices = torch.repeat_interleave(m_indices, M, dim=0) + m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) print("GOT HERE2") @@ -525,7 +525,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, print(f"GOT HERE9 {out.shape}, {M}, {num_tokens}") - TT = topk_weight.shape[0] tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) @@ -562,8 +561,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() From 78034ff7737bde313d77a39805cf53f5e32b7f0a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 01:29:57 +0000 Subject: [PATCH 072/171] cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ea1b127f3fb6..ab2d3652bde5 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -419,8 +419,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #sorted_token_ids = sorted_token_ids[:num_pad] - print("GOT HERE1") - num_tokens = topk * M pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() @@ -437,18 +435,12 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("orig m_indices", m_indices) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - print("GOT HERE2") - assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - print("GOT HERE2A") - inv_perm = torch.argsort(sorted_token_ids) p("m_indices", m_indices) - print("GOT HERE2B") - a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales @@ -459,14 +451,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) - print("GOT HERE2C") - # Permute activations according to sorted token ids a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] - print("GOT HERE3") - #a_q.view(dtype=torch.uint8)[mask] = 0 p("topk_ids", topk_ids) @@ -482,14 +470,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, device=a.device) - print("GOT HERE4") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) - print("GOT HERE5") - #inter_out = inter_out[inv_perm, ...] act_out = SiluAndMul().forward_native(inter_out) @@ -503,17 +487,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) - print("GOT HERE6") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - print("GOT HERE7") - out = out[inv_perm,...] - print("GOT HERE8") - #topk_weight = topk_weight[inv_perm] #out[:,num_pad:] = 0 @@ -523,20 +501,14 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #final_out = (out.view(M, -1, w2.shape[1]) * # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) - print(f"GOT HERE9 {out.shape}, {M}, {num_tokens}") - tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) - print(f"GOT HERE10 {tmp_out.shape}, {topk_weight.shape}") - final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) #final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * # topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - print("GOT HERE11") - p("final_out", final_out) # TODO use moe_sum @@ -574,7 +546,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") - print(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -592,8 +563,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - score = torch.randn((M, E), dtype=dtype) # does not work - #score = iota((M, E), dtype=dtype) + score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n From 0549dc27100fac2972cbaa0f0bc8cf32e16c8700 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 04:21:38 +0000 Subject: [PATCH 073/171] some integration tests working Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 44 +++++-------------- .../layers/fused_moe/fused_moe.py | 3 +- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ab2d3652bde5..75afbb9b0293 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -188,7 +188,7 @@ def torch2_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): for i in range(w1.shape[0]): mask = topk_ids == i - print(f"sum = {mask.numel()}, {mask.nonzero()}") + #print(f"sum = {mask.numel()}, {mask.nonzero()}") if mask.sum(): inter_out = native_w8a8_block_fp8_matmul(a_q[mask], w1[i], @@ -411,14 +411,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, block_m, num_groups, None) pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") - #sorted_token_ids = sorted_token_ids[:num_pad] - num_tokens = topk * M pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() @@ -427,11 +424,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - #m_indices = m_indices[(sorted_token_ids.numel() // 128):] - p("sorted_token_ids", sorted_token_ids) - p("sorted_token_ids[:num_pad]", sorted_token_ids[:num_pad]) - #sorted_token_ids = sorted_token_ids[:num_pad] p("orig m_indices", m_indices) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) @@ -439,8 +432,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, inv_perm = torch.argsort(sorted_token_ids) - p("m_indices", m_indices) - a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales @@ -455,8 +446,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] - #a_q.view(dtype=torch.uint8)[mask] = 0 - p("topk_ids", topk_ids) p("sorted", sorted_token_ids) p("topk_weight", topk_weight) @@ -469,19 +458,15 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) + #print(f"inter_out {inter_out.shape}") deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) - #inter_out = inter_out[inv_perm, ...] - act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) -# act_out_q = act_out_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) -# act_out_s = act_out_s[sorted_token_ids] - out = torch.zeros(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, @@ -492,22 +477,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, out = out[inv_perm,...] - #topk_weight = topk_weight[inv_perm] - #out[:,num_pad:] = 0 - p("inter_out", inter_out) - #p("out", out) - - #final_out = (out.view(M, -1, w2.shape[1]) * - # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] - #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) + + #print(f"tk {topk_weight.shape}, M={M} topk={topk}, N={w2.shape[1]}, out_C={out.shape}") + #print(f"tmp_out {tmp_out.shape}") final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - #final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * - # topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + #print(f"final_out {final_out.shape}") p("final_out", final_out) @@ -546,6 +525,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + #print(f"\n\n\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -597,6 +577,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, with set_current_vllm_config(vllm_config): if False: + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size) + out = fused_moe( a, w1, @@ -608,11 +591,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, - allow_deep_gemm=False + allow_deep_gemm=True ) - - ref_out = torch_w8a8_block_fp8_moe( - a, w1, w2, w1_s, w2_s, score, topk, block_size) else: out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_size) @@ -628,7 +608,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, - allow_deep_gemm=False + allow_deep_gemm=True ) #print(f"{out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fb0440b560e1..9f4f521752b3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1361,6 +1361,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, break if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + assert False # for now # Adjust the intermediate cache size and config for the last # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and @@ -1382,7 +1383,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + moe_align_block_size(curr_topk_ids, block_m, global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, From 14d05696390e4b8a52ce8f46e71fa4af0dfbe5d0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 16:59:11 +0000 Subject: [PATCH 074/171] almost all tests passing Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 3 +- .../layers/fused_moe/fused_moe.py | 41 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 75afbb9b0293..142fd368083e 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -516,6 +516,7 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -576,7 +577,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if False: + if True: ref_out = torch_w8a8_block_fp8_moe( a, w1, w2, w1_s, w2_s, score, topk, block_size) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9f4f521752b3..c625b5f682dc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1302,6 +1302,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, config = get_config_func(M) + # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX + # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty(M * top_k_num * max(N, K), @@ -1315,6 +1317,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, device=hidden_states.device, dtype=hidden_states.dtype) + # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX + if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1331,6 +1335,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + block_m = config['BLOCK_SIZE_M'] + assert not use_dg or block_m == 128 + if use_dg: #print("USE_DG!!!!!!!!!!!!!") # TODO: how to test chunks? @@ -1344,7 +1351,41 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() #print("GOT HERE B") + + # BIG HACK + sorted_token_ids, _, _ = ( + moe_align_block_size(topk_ids, block_m, + global_num_experts, expert_map)) + + num_tokens = top_k_num * M + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + if pad_size > 0: + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) + + new_M = sorted_token_ids.numel()//top_k_num + #print(f"fused2 m={M}, new_M={new_M}, sort={sorted_token_ids.shape}, hs={hidden_states.shape}, hs[sort]={hidden_states.view(num_tokens, -1)[sorted_token_ids, ...].shape}") + + intermediate_cache1 = torch.empty((new_M, top_k_num, N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((new_M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((new_M, top_k_num, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) else: + intermediate_cache1 = torch.empty((M, top_k_num, N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + num_chunks = (num_tokens // CHUNK_SIZE) + 1 if num_chunks > 1: From ac2a339b468053a2cacd4bca4e99e4be4eb1894b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 18:19:10 +0000 Subject: [PATCH 075/171] cleanup temp construction a bit Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 ++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 142fd368083e..828f258a877f 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -512,9 +512,9 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) # all work #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [8], [6], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c625b5f682dc..1fa5ebee78d9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1353,7 +1353,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, #print("GOT HERE B") # BIG HACK - sorted_token_ids, _, _ = ( + sorted_token_ids, _, pad = ( moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) @@ -1363,8 +1363,15 @@ def fused_experts_impl(hidden_states: torch.Tensor, sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - new_M = sorted_token_ids.numel()//top_k_num - #print(f"fused2 m={M}, new_M={new_M}, sort={sorted_token_ids.shape}, hs={hidden_states.shape}, hs[sort]={hidden_states.view(num_tokens, -1)[sorted_token_ids, ...].shape}") + #new_M = sorted_token_ids.numel()//top_k_num + #print(f"fused2 m={M}, sort={sorted_token_ids.shape}, pad={pad}, hs={hidden_states.shape}, num_tok={num_tokens}") + #print(f"hs[sort]={torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape}") + new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape + #new_top_k = new_S[0] // M + new_M = new_S[0] // top_k_num + #new_M = ((new_M + block_m - 1) // block_m) * block_m + #print(f"fused2 new_M_b={new_M} top_k = {top_k_num}, new_top_k={new_top_k}") + #top_k_num = new_top_k intermediate_cache1 = torch.empty((new_M, top_k_num, N), device=hidden_states.device, From d87b30588b7efcfe119014cf79ae2ff3c3865136 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 20:09:59 +0000 Subject: [PATCH 076/171] fix rest of tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 8 +--- .../layers/fused_moe/fused_moe.py | 38 +++++++------------ 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 828f258a877f..1bd4d20f1124 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -511,12 +511,7 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: # topk 6 broken/slow @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) # all work - #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [8], [6], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([512], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -526,7 +521,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") - #print(f"\n\n\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1fa5ebee78d9..5454333ce290 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -531,10 +531,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) - p("fused a_q", A) - p("fused a_s", A_scale) - p("fused expert ids", expert_ids) - if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 @@ -1339,20 +1335,22 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: - #print("USE_DG!!!!!!!!!!!!!") # TODO: how to test chunks? - #num_chunks = 1 - #CHUNK_SIZE = num_tokens - num_chunks = (num_tokens // CHUNK_SIZE) + 1 + if False: + num_chunks = 1 + CHUNK_SIZE = num_tokens + else: + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + assert w1_scale is not None assert w2_scale is not None + # TODO: do this offline - #print("GOT HERE A") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - #print("GOT HERE B") - # BIG HACK + + # TODO: this could be smarter sorted_token_ids, _, pad = ( moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) @@ -1362,24 +1360,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, if pad_size > 0: sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - - #new_M = sorted_token_ids.numel()//top_k_num - #print(f"fused2 m={M}, sort={sorted_token_ids.shape}, pad={pad}, hs={hidden_states.shape}, num_tok={num_tokens}") - #print(f"hs[sort]={torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape}") new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape - #new_top_k = new_S[0] // M - new_M = new_S[0] // top_k_num - #new_M = ((new_M + block_m - 1) // block_m) * block_m - #print(f"fused2 new_M_b={new_M} top_k = {top_k_num}, new_top_k={new_top_k}") - #top_k_num = new_top_k + new_M = new_S[0] - intermediate_cache1 = torch.empty((new_M, top_k_num, N), + intermediate_cache1 = torch.empty((new_M, N), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((new_M * top_k_num, N // 2), + intermediate_cache2 = torch.empty((new_M, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((new_M, top_k_num, w2.shape[1]), + intermediate_cache3 = torch.empty((new_M, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) else: @@ -1455,8 +1445,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - p("fused inter_out", intermediate_cache1) - if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From 7fcdd1cf70ecc908889a966fa8c4553489a336a9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 20:52:10 +0000 Subject: [PATCH 077/171] cleanups + format Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 212 ++++++------------ .../layers/fused_moe/fused_moe.py | 64 ++++-- 2 files changed, 103 insertions(+), 173 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 1bd4d20f1124..3fe432e61b15 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -1,11 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/sgl-project/sglang/pull/2575 -# TODO: try/catch this? + import itertools from typing import Tuple -import deep_gemm +dg_available = False +try: + import deep_gemm + dg_available = True +except: + pass + import pytest import torch @@ -28,30 +34,21 @@ D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] M = [1, 7, 8, 83, 84, 512, 2048, 4096] -N = [128, 512, 1024, 4096, 7748, 13824, 7168] +N = [128, 512, 1024, 4096, 7168, 7748, 13824] K = [256, 4096, 5120, 3884, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] M_moe_dg = [128, 512, 2048] -N_moe = [128, 256, 4608] # [128, 4608, 13824] -K_moe = [256, 512, 7168] # [256, 7168, 13824] +N_moe = [128, 256, 4608] # [13824] +K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16, 24] +E = [2, 8, 16, 24] # [128, 256] TOP_KS = [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] -def p(s, t): - #print(f"{s}: {t.shape}\n{t}") - pass - -def pp(x): - #print(x) - pass - - def native_per_token_group_quant_fp8(x, group_size, eps=1e-10, @@ -168,48 +165,6 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -def torch2_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(torch.float32) - - a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) - - assert topk_ids.numel() == a_q.shape[0] == B * topk - - for i in range(w1.shape[0]): - mask = topk_ids == i - #print(f"sum = {mask.numel()}, {mask.nonzero()}") - if mask.sum(): - inter_out = native_w8a8_block_fp8_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_fp8_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - # Skip all tests if CUDA is not available pytest.importorskip("torch.cuda") @@ -306,6 +261,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -393,14 +349,13 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -# dtype=torch.float8_e4m3fn def fp8_perm(m, idx): - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn) def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using DeepGemm torch.""" + """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" num_groups = w1.shape[0] M, K = a.shape N = w2.shape[-1] @@ -414,18 +369,17 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, block_m, num_groups, None) - pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") - num_tokens = topk * M - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * + block_m) - sorted_token_ids.numel() if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, + (0, pad_size), "constant", + num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - p("sorted_token_ids", sorted_token_ids) - p("orig m_indices", m_indices) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 @@ -436,34 +390,21 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, # Replicate activations and scales a_q = a_q.view(a_q.shape[0], -1, - a_q.shape[1]).repeat(1, topk, - 1).reshape(-1, a_q.shape[1]) + a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) a_s = a_s.view(a_s.shape[0], -1, - a_s.shape[1]).repeat(1, topk, - 1).reshape(-1, a_s.shape[1]) + a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # Permute activations according to sorted token ids a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] - p("topk_ids", topk_ids) - p("sorted", sorted_token_ids) - p("topk_weight", topk_weight) - - p("a_q", a_q) - p("a_s", a_s) - p("m_indices", m_indices) - inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) - #print(f"inter_out {inter_out.shape}") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) - act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) @@ -475,54 +416,31 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - out = out[inv_perm,...] - - p("inter_out", inter_out) - - tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] + out = out[inv_perm, ...] - #print(f"tk {topk_weight.shape}, M={M} topk={topk}, N={w2.shape[1]}, out_C={out.shape}") - #print(f"tmp_out {tmp_out.shape}") + tmp_out = out[:(M * topk), ...].view(-1, topk, w2.shape[1]) final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - #print(f"final_out {final_out.shape}") - - p("final_out", final_out) - - # TODO use moe_sum + # TODO use moe_sum? return final_out -def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: - dimensions = [] - - for index, _ in enumerate(shape): - if index != dim: - dimension = 1 - else: - dimension = shape[index] - - dimensions = [*dimensions, dimension] - - return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) - -# topk 6 broken/slow @pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + "M,N,K,E,topk,block_size,dtype,seed,test_baseline", + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS, [True, False])) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, - dtype, seed): + dtype, seed, test_baseline): # only aligned sizes if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk > E): - pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") - - pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") - - torch.set_printoptions(profile="full") + pytest.skip( + f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}" + ) vllm_config = VllmConfig() @@ -571,40 +489,36 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if True: - ref_out = torch_w8a8_block_fp8_moe( - a, w1, w2, w1_s, w2_s, score, topk, block_size) - - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True - ) + if not test_baseline: + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + out = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + allow_deep_gemm=True) else: - out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - - ref_out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True - ) + ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + + out = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + allow_deep_gemm=True) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5454333ce290..66d3ef24ffba 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -41,6 +41,7 @@ def p(s, t): #print(f"{s}: {t.shape}\n{t}") pass + def pp(x): #print(x) pass @@ -781,12 +782,19 @@ def get_default_config( # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": 64 if not use_deep_gemm else dg.get_m_alignment_for_contiguous_layout(), - "BLOCK_SIZE_N": block_shape[0], - "BLOCK_SIZE_K": block_shape[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3, + "BLOCK_SIZE_M": + 64 if not use_deep_gemm else + dg.get_m_alignment_for_contiguous_layout(), + "BLOCK_SIZE_N": + block_shape[0], + "BLOCK_SIZE_K": + block_shape[1], + "GROUP_SIZE_M": + 32, + "num_warps": + 4, + "num_stages": + 3, } elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: # moe wna16 kernels @@ -818,10 +826,15 @@ def get_default_config( else: dg_config = use_deep_gemm and dtype == "fp8_w8a8" config = { - "BLOCK_SIZE_M": 64 if not dg_config else dg.get_m_alignment_for_contiguous_layout(), - "BLOCK_SIZE_N": 64 if not dg_config else 128, - "BLOCK_SIZE_K": 32 if not dg_config else 128, - "GROUP_SIZE_M": 8, + "BLOCK_SIZE_M": + 64 + if not dg_config else dg.get_m_alignment_for_contiguous_layout(), + "BLOCK_SIZE_N": + 64 if not dg_config else 128, + "BLOCK_SIZE_K": + 32 if not dg_config else 128, + "GROUP_SIZE_M": + 8, } return config @@ -1329,14 +1342,15 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, + use_fp8_w8a8) block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == 128 if use_dg: # TODO: how to test chunks? - if False: + if True: num_chunks = 1 CHUNK_SIZE = num_tokens else: @@ -1349,18 +1363,21 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - # TODO: this could be smarter - sorted_token_ids, _, pad = ( - moe_align_block_size(topk_ids, block_m, - global_num_experts, expert_map)) + sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, + global_num_experts, + expert_map)) num_tokens = top_k_num * M - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * + block_m) - sorted_token_ids.numel() if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, + (0, pad_size), + "constant", num_tokens) + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + new_S = torch.repeat_interleave(hidden_states, top_k_num, + dim=0)[sorted_token_ids, ...].shape new_M = new_S[0] intermediate_cache1 = torch.empty((new_M, N), @@ -1399,7 +1416,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, break if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - assert False # for now + assert not use_dg # for now # Adjust the intermediate cache size and config for the last # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and @@ -1421,8 +1438,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, block_m, - global_num_experts, expert_map)) + moe_align_block_size(curr_topk_ids, block_m, global_num_experts, + expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, w1, @@ -1484,7 +1501,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) - return out_hidden_states From ed3610e18295b3cfdfe2a8fea5ee2c517590cc00 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 22:52:59 +0000 Subject: [PATCH 078/171] do more of output computation in place Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 66d3ef24ffba..8c8ecd5e0a49 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1531,7 +1531,7 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False, + allow_deep_gemm: bool = True, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of From e39f8c88e328014b5402600a5bd1c838a7ce8d34 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 22:57:19 +0000 Subject: [PATCH 079/171] add env var Signed-off-by: Bill Nell --- .../model_executor/layers/fused_moe/fused_moe.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8c8ecd5e0a49..3fb293ffa46b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -29,24 +29,15 @@ logger = init_logger(__name__) use_deep_gemm = False -if True or envs.VLLM_USE_DEEP_GEMM: +if envs.VLLM_USE_DEEP_GEMM: try: import deep_gemm as dg + logger.info("Using DeepGemm for fused MoE.") use_deep_gemm = True except ImportError: logger.warning("Failed to import DeepGemm kernels.") -def p(s, t): - #print(f"{s}: {t.shape}\n{t}") - pass - - -def pp(x): - #print(x) - pass - - @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, @@ -1402,9 +1393,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - if num_chunks > 1: - print("CHUNKS!!!!!!!!!!!!!!!!!!") - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, From adf85f17141c0a57ff859c70ec3520343ff1c4cf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 04:23:27 +0000 Subject: [PATCH 080/171] formatting, remove some blocking restrictions Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 24 +++++++++---------- .../layers/fused_moe/fused_moe.py | 15 ++++++------ .../compressed_tensors_moe.py | 2 ++ .../model_executor/layers/quantization/fp8.py | 1 + 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3fe432e61b15..9e6bfc4018e6 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -1,17 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/sgl-project/sglang/pull/2575 - import itertools from typing import Tuple -dg_available = False -try: - import deep_gemm - dg_available = True -except: - pass - import pytest import torch @@ -24,6 +16,13 @@ per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +dg_available = False +try: + import deep_gemm + dg_available = True +except ImportError: + pass + if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) @@ -39,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 512, 2048] +M_moe_dg = [128, 192, 512, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] @@ -358,7 +357,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" num_groups = w1.shape[0] M, K = a.shape - N = w2.shape[-1] topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) @@ -437,10 +435,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed, test_baseline): # only aligned sizes - if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk > E): + if ((M % 128 != 0 and test_baseline) or N % 128 != 0 or K % 128 != 0 + or topk > E): pytest.skip( - f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}" - ) + f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") vllm_config = VllmConfig() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3fb293ffa46b..2dae10988ac5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1333,24 +1333,23 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, + use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, use_fp8_w8a8) block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == 128 if use_dg: - # TODO: how to test chunks? - if True: - num_chunks = 1 - CHUNK_SIZE = num_tokens - else: - num_chunks = (num_tokens // CHUNK_SIZE) + 1 + if M % 128 != 0: + CHUNK_SIZE = (M // 128) * 128 + num_chunks = (num_tokens // CHUNK_SIZE) + 1 assert w1_scale is not None assert w2_scale is not None - # TODO: do this offline + # We attempt to do this offline in Fp8MoEMethod, in which case these + # calls will be nops. Otherwise, they'll be performed every time the + # layer is executed. w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 721e36af2b28..751f1aaf2c38 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -272,6 +272,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts + # TODO: do we need to do deep gemm alignment here? + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5515ba27ea19..e184c3173a4c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -443,6 +443,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + self.allow_deep_gemm = use_deep_gemm # Check for DeepGemm support. self.allow_deep_gemm = False From 8e931603fc5ddb5436f3cb5d25d3a7d9daace798 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 12:59:57 +0000 Subject: [PATCH 081/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 9e6bfc4018e6..036baace5d77 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 192, 512, 2048] +M_moe_dg = [128, 512, 2048] #192 N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2dae10988ac5..c645411a4181 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1340,6 +1340,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: + #print("USE_DG") if M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 num_chunks = (num_tokens // CHUNK_SIZE) + 1 @@ -1354,11 +1355,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() # TODO: this could be smarter + num_tokens = top_k_num * M sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) - num_tokens = top_k_num * M pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() if pad_size > 0: @@ -1368,6 +1369,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape + #new_M = hidden_states.shape[0] * top_k_num * global_num_experts new_M = new_S[0] intermediate_cache1 = torch.empty((new_M, N), From d81062bc100a277ff1600215d050f70c56243a9b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 13:19:08 +0000 Subject: [PATCH 082/171] fix resizing of output Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c645411a4181..9e2df09cf3f2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -28,14 +28,12 @@ logger = init_logger(__name__) -use_deep_gemm = False -if envs.VLLM_USE_DEEP_GEMM: - try: - import deep_gemm as dg - logger.info("Using DeepGemm for fused MoE.") - use_deep_gemm = True - except ImportError: - logger.warning("Failed to import DeepGemm kernels.") +has_deep_gemm = False +try: + import deep_gemm as dg + has_deep_gemm = True +except ImportError: + pass @triton.jit @@ -768,6 +766,7 @@ def get_default_config( dtype: Optional[str], is_marlin: bool, block_shape: Optional[List[int]] = None, + use_deep_gemm: bool = False, ) -> Dict[str, int]: if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] @@ -838,6 +837,7 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, + use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -859,7 +859,7 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - is_marlin, block_shape) + is_marlin, block_shape, use_deep_gemm) return config @@ -1298,6 +1298,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, top_k_num, config_dtype, block_shape=block_shape, + use_deep_gemm=has_deep_gemm and allow_deep_gemm, # hacky ) config = get_config_func(M) @@ -1340,8 +1341,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: - #print("USE_DG") - if M % 128 != 0: + if False and M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 num_chunks = (num_tokens // CHUNK_SIZE) + 1 From b2ea85ca230e8dbdae94e2380a2ea35ddf1cc5bb Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 13:19:41 +0000 Subject: [PATCH 083/171] fix resizing of output Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 036baace5d77..9e6bfc4018e6 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 512, 2048] #192 +M_moe_dg = [128, 192, 512, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9e2df09cf3f2..38641e529389 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1341,7 +1341,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: - if False and M % 128 != 0: + if M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 num_chunks = (num_tokens // CHUNK_SIZE) + 1 From 37053bd5694d8d9bb95bb8142363f161a5f68cea Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 13:32:40 +0000 Subject: [PATCH 084/171] fixes Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- .../layers/fused_moe/fused_moe.py | 18 +++++------------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 9e6bfc4018e6..4855fdb69952 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 192, 512, 2048] +M_moe_dg = [128, 512, 2048] # 192 N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 38641e529389..6f6e98356799 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1354,23 +1354,12 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - # TODO: this could be smarter - num_tokens = top_k_num * M + # TODO: computing new_M could be smarter sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * - block_m) - sorted_token_ids.numel() - if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, - (0, pad_size), - "constant", num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - new_S = torch.repeat_interleave(hidden_states, top_k_num, - dim=0)[sorted_token_ids, ...].shape - #new_M = hidden_states.shape[0] * top_k_num * global_num_experts - new_M = new_S[0] + new_M = ((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m intermediate_cache1 = torch.empty((new_M, N), device=hidden_states.device, @@ -1394,6 +1383,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 + # TODO: modify CHUNK_SIZE to be % 128 == 0 and check if each chunk is + # valid dg. fall back to old kernel if not + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, From bcb245a31cc67e787f2b7f3cf8aa29940707803e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 22:29:32 +0000 Subject: [PATCH 085/171] aligned chunking working for deep gemm Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 9 ++--- .../layers/fused_moe/fused_moe.py | 34 ++++++++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 4855fdb69952..599909a7056f 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -427,15 +427,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed,test_baseline", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS, [True, False])) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + #itertools.product([254],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True])) + #itertools.product([512],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed, test_baseline): # only aligned sizes - if ((M % 128 != 0 and test_baseline) or N % 128 != 0 or K % 128 != 0 + if ((M % 128 != 0 and not test_baseline) or N % 128 != 0 or K % 128 != 0 or topk > E): pytest.skip( f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") @@ -487,7 +488,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if not test_baseline: + if test_baseline: ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6f6e98356799..ad4ff41d6414 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1340,10 +1340,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == 128 + chunked_dg = False if use_dg: + #print("USE_DG") + #CHUNK_SIZE = 128 if M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 + #print(f"DG_CHUNK {CHUNK_SIZE}") + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + chunked_dg = num_chunks > 1 assert w1_scale is not None assert w2_scale is not None @@ -1393,20 +1399,12 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape + skip_dg = tokens_in_chunk % 128 != 0 + if tokens_in_chunk == 0: break - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - assert not use_dg # for now - # Adjust the intermediate cache size and config for the last - # chunk. Note that in most cases we only have one chunk - # so the cache size and config are already set correctly and - # do not need to be adjusted. - intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.shape[1]] - intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] - config = get_config_func(tokens_in_chunk) + #print(f"LOOP skip={skip_dg} tic={tokens_in_chunk}, chunk={chunk}") curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1480,8 +1478,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + if use_dg and not skip_dg: + assert inv_perm is not None + M = curr_topk_weights.shape[0] + out_C = intermediate_cache3[inv_perm, ...] + out_C = out_C[:(M * top_k_num), ...] + out_C = out_C.view(-1, top_k_num, w2.shape[1]) + out_C.mul_(curr_topk_weights.view(M, -1, 1)) + tmp_cache3 = out_C + else: + tmp_cache3 = intermediate_cache3.view(*intermediate_cache3.shape) + + ops.moe_sum(tmp_cache3, out_hidden_states[begin_chunk_idx:end_chunk_idx]) + return out_hidden_states From f585c5d4e39ca18a395ea13718beecb5e2cba788 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 00:03:55 +0000 Subject: [PATCH 086/171] unaligned chunking for deep gemm Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 +--- vllm/model_executor/layers/fused_moe/fused_moe.py | 7 +------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 599909a7056f..b4fe9135a887 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 512, 2048] # 192 +M_moe_dg = [128, 192, 512, 1335, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] @@ -428,8 +428,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed,test_baseline", itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) - #itertools.product([254],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True])) - #itertools.product([512],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ad4ff41d6414..fecefea8b062 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1342,11 +1342,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: - #print("USE_DG") - #CHUNK_SIZE = 128 if M % 128 != 0: - CHUNK_SIZE = (M // 128) * 128 - #print(f"DG_CHUNK {CHUNK_SIZE}") + CHUNK_SIZE = (M // 128) * 128 # min with env? num_chunks = (num_tokens // CHUNK_SIZE) + 1 chunked_dg = num_chunks > 1 @@ -1404,8 +1401,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break - #print(f"LOOP skip={skip_dg} tic={tokens_in_chunk}, chunk={chunk}") - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] From 6dd17e5fa4dd4850ef7eaab01ba4a529cec84689 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 17:17:27 +0000 Subject: [PATCH 087/171] cleanup wip Signed-off-by: Bill Nell --- requirements/test.txt | 6 ++ tests/kernels/test_block_fp8.py | 3 +- .../layers/fused_moe/fused_moe.py | 60 ++++++++----------- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 2e8121e3882e..03093e134524 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -132,6 +132,10 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -772,9 +776,11 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index b4fe9135a887..9ba3a105cc57 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -427,7 +427,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed,test_baseline", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fecefea8b062..ae210a6c2bb8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -521,7 +521,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) - if (use_int8_w8a16 or use_int4_w4a16) and \ + if use_dg: + # Note: we do not apply weights here since it requires + # resizing the output. + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (A, A_scale), (B, B_scale), C, expert_ids) + + elif (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 @@ -772,19 +778,12 @@ def get_default_config( # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] config = { - "BLOCK_SIZE_M": - 64 if not use_deep_gemm else - dg.get_m_alignment_for_contiguous_layout(), - "BLOCK_SIZE_N": - block_shape[0], - "BLOCK_SIZE_K": - block_shape[1], - "GROUP_SIZE_M": - 32, - "num_warps": - 4, - "num_stages": - 3, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, } elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: # moe wna16 kernels @@ -814,17 +813,11 @@ def get_default_config( "GROUP_SIZE_M": 1, } else: - dg_config = use_deep_gemm and dtype == "fp8_w8a8" config = { - "BLOCK_SIZE_M": - 64 - if not dg_config else dg.get_m_alignment_for_contiguous_layout(), - "BLOCK_SIZE_N": - 64 if not dg_config else 128, - "BLOCK_SIZE_K": - 32 if not dg_config else 128, - "GROUP_SIZE_M": - 8, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, } return config @@ -837,7 +830,6 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, - use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -859,7 +851,7 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - is_marlin, block_shape, use_deep_gemm) + is_marlin, block_shape) return config @@ -1298,7 +1290,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, top_k_num, config_dtype, block_shape=block_shape, - use_deep_gemm=has_deep_gemm and allow_deep_gemm, # hacky ) config = get_config_func(M) @@ -1337,13 +1328,15 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, use_fp8_w8a8) - block_m = config['BLOCK_SIZE_M'] - assert not use_dg or block_m == 128 + config_block_m = config['BLOCK_SIZE_M'] + block_m = config_block_m if not use_dg else dg.get_m_alignment_for_contiguous_layout() + + assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() chunked_dg = False if use_dg: - if M % 128 != 0: - CHUNK_SIZE = (M // 128) * 128 # min with env? + if M % block_m != 0: + CHUNK_SIZE = min((M //block_m) * block_m, CHUNK_SIZE) num_chunks = (num_tokens // CHUNK_SIZE) + 1 chunked_dg = num_chunks > 1 @@ -1386,9 +1379,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - # TODO: modify CHUNK_SIZE to be % 128 == 0 and check if each chunk is - # valid dg. fall back to old kernel if not - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1396,7 +1386,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape - skip_dg = tokens_in_chunk % 128 != 0 + skip_dg = use_dg and tokens_in_chunk % 128 != 0 #block_m != 0 if tokens_in_chunk == 0: break From e150caa1a28e1f95145219f4ba8ae31add6d64bb Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 20:51:59 +0000 Subject: [PATCH 088/171] clean up some blocking stuff Signed-off-by: Bill Nell --- requirements/test.txt | 6 -- tests/kernels/test_block_fp8.py | 67 +++++++++---------- .../layers/fused_moe/fused_moe.py | 19 +++--- 3 files changed, 42 insertions(+), 50 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 03093e134524..2e8121e3882e 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -132,10 +132,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -776,11 +772,9 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 9ba3a105cc57..ec16ac30a770 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 192, 512, 1335, 2048] +M_moe_dg = [1, 128, 192, 512, 1335, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] @@ -426,17 +426,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed,test_baseline", - #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) - itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, - dtype, seed, test_baseline): + dtype, seed): # only aligned sizes - if ((M % 128 != 0 and not test_baseline) or N % 128 != 0 or K % 128 != 0 - or topk > E): + if (N % 128 != 0 or K % 128 != 0 or topk > E): pytest.skip( f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") @@ -487,36 +486,26 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if test_baseline: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - - out = fused_moe(a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + if M % 128 == 0: + ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) else: - ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - - out = fused_moe(a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True) + ref_out2 = None + + out = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + allow_deep_gemm=True) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -525,3 +514,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 + + if ref_out2 is not None: + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out2.to(torch.float32))) / + torch.mean(torch.abs(ref_out2.to(torch.float32)))) + assert rel_diff < 0.03 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ae210a6c2bb8..00ad0c37caf6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -522,6 +522,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B.shape[1], META['BLOCK_SIZE_N']), ) if use_dg: + assert use_fp8_w8a8 # Note: we do not apply weights here since it requires # resizing the output. dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -772,7 +773,6 @@ def get_default_config( dtype: Optional[str], is_marlin: bool, block_shape: Optional[List[int]] = None, - use_deep_gemm: bool = False, ) -> Dict[str, int]: if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] @@ -830,6 +830,7 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, + use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -852,6 +853,12 @@ def try_get_optimal_moe_config( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) + + + # Remove this + if use_deep_gemm: + config['BLOCK_SIZE_M'] = 128 + return config @@ -1325,12 +1332,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, - use_fp8_w8a8) - - config_block_m = config['BLOCK_SIZE_M'] - block_m = config_block_m if not use_dg else dg.get_m_alignment_for_contiguous_layout() - + block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() chunked_dg = False @@ -1379,6 +1381,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1386,7 +1389,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape - skip_dg = use_dg and tokens_in_chunk % 128 != 0 #block_m != 0 + skip_dg = use_dg and tokens_in_chunk % block_m != 0 if tokens_in_chunk == 0: break From f4d5441475d79fbcdaa860294009d565cebe8797 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 20:58:38 +0000 Subject: [PATCH 089/171] clean up some blocking stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 00ad0c37caf6..ef5436132a07 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -854,10 +854,9 @@ def try_get_optimal_moe_config( config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - - # Remove this + # Try to remove this if use_deep_gemm: - config['BLOCK_SIZE_M'] = 128 + config['BLOCK_SIZE_M'] = dg.get_m_alignment_for_contiguous_layout() return config @@ -1389,11 +1388,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape - skip_dg = use_dg and tokens_in_chunk % block_m != 0 - if tokens_in_chunk == 0: break + skip_dg = use_dg and tokens_in_chunk % block_m != 0 + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] From 3b5f459e1e3930421fc9e174578d4bdab4d640db Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 14 Mar 2025 23:40:03 +0000 Subject: [PATCH 090/171] tweaks Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++++++++-------- .../layers/fused_moe/fused_moe.py | 3 +-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ec16ac30a770..11d35ec345a4 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -427,17 +427,18 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes - if (N % 128 != 0 or K % 128 != 0 or topk > E): + if (N % 128 != 0 or K % 128 != 0 or topk > E or block_size != [128, 128]): pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}, " + f"block_size={block_size}") vllm_config = VllmConfig() @@ -486,12 +487,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) if M % 128 == 0: - ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) + ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, + w2_s, score, topk, + block_size) else: ref_out2 = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ef5436132a07..8c1dc35e276d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1337,7 +1337,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: if M % block_m != 0: - CHUNK_SIZE = min((M //block_m) * block_m, CHUNK_SIZE) + CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) num_chunks = (num_tokens // CHUNK_SIZE) + 1 chunked_dg = num_chunks > 1 @@ -1380,7 +1380,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, From d8771fabfeb9af332e20a874e8b621406e78bb8c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 15 Mar 2025 00:02:28 +0000 Subject: [PATCH 091/171] fix rebase Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8c1dc35e276d..9ea60e481341 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1311,9 +1311,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 - intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) + #intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + # device=hidden_states.device, + # dtype=hidden_states.dtype) # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX @@ -1358,25 +1358,31 @@ def fused_experts_impl(hidden_states: torch.Tensor, new_M = ((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m - intermediate_cache1 = torch.empty((new_M, N), - device=hidden_states.device, - dtype=hidden_states.dtype) + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(new_M * max(N, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = cache13[:(new_M * N)].view(new_M, N) intermediate_cache2 = torch.empty((new_M, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((new_M, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view(new_M, w2.shape[1]) else: - intermediate_cache1 = torch.empty((M, top_k_num, N), - device=hidden_states.device, - dtype=hidden_states.dtype) + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = cache13[:M * top_k_num * N].view( + (M, topk_ids.shape[1], N)) intermediate_cache2 = torch.empty((M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view( + (M, topk_ids.shape[1], w2.shape[1])) num_chunks = (num_tokens // CHUNK_SIZE) + 1 From 00ad23a83fdd22023d5b5f16f6fa6e1ee9e5a235 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 17 Mar 2025 16:15:15 +0000 Subject: [PATCH 092/171] rebase Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 3 +-- .../layers/fused_moe/fused_moe.py | 20 ++----------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 11d35ec345a4..d787eb0044a0 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -2,7 +2,6 @@ # Adapted from https://github.com/sgl-project/sglang/pull/2575 import itertools -from typing import Tuple import pytest import torch @@ -288,7 +287,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): def per_block_cast_to_fp8( x: torch.Tensor, - block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape x_padded = torch.zeros( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9ea60e481341..60248b86a453 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1300,23 +1300,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, config = get_config_func(M) - # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX - - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - cache13 = torch.empty(M * top_k_num * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) - intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) - - # This needs separate memory since it's used concurrently with cache1 - #intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - # device=hidden_states.device, - # dtype=hidden_states.dtype) - - # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX - if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1368,7 +1351,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache2 = torch.empty((new_M, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view(new_M, w2.shape[1]) + intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( + new_M, w2.shape[1]) else: # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 From 833182f47d64de5354032de06bfe690cc866cc01 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 21 Mar 2025 21:52:56 +0000 Subject: [PATCH 093/171] refactoring + minor perf improvements Signed-off-by: Bill Nell --- benchmarks/kernels/benchmark_moe.py | 26 ++-- tests/kernels/test_block_fp8.py | 118 ++++++++++++------ .../layers/fused_moe/fused_moe.py | 32 ++--- 3 files changed, 111 insertions(+), 65 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a274537a6751..a3092c44e332 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -30,18 +30,20 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config(config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, - use_deep_gemm: bool = False) -> float: +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False +) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index d787eb0044a0..f75f2f2f5f5f 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -228,6 +228,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + if topk > E: + pytest.skip(f"Skipping test; topk={K} > E={E}") + torch.manual_seed(seed) factor_for_scale = 1e-2 fp8_info = torch.finfo(torch.float8_e4m3fn) @@ -276,8 +279,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - print(f"{out.sum()=}") - print(f"{ref_out.sum()=}") + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / @@ -348,21 +351,16 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def fp8_perm(m, idx): - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn) + if m.dtype == torch.float8_e4m3fn: + return m.view(dtype=torch.uint8)[idx, + ...].view(dtype=torch.float8_e4m3fn) + else: + return m[idx, ...] -def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" - num_groups = w1.shape[0] +def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): M, K = a.shape - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - - _, block_k = block_shape[0], block_shape[1] - sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, block_m, num_groups, None) @@ -381,19 +379,54 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - inv_perm = torch.argsort(sorted_token_ids) + #print(f"sti {sorted_token_ids}") + + inv_perm = torch.argsort(sorted_token_ids)[:M*topk] + + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + a = fp8_perm(a, sorted_token_ids) + + if a_s is not None: + a_s = a_s.view(M, -1, K // 128).repeat(1, topk, + 1).reshape(-1, K // 128) + a_s = a_s[sorted_token_ids] + + return a, a_s, m_indices, inv_perm + + +def test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, M, K, + topk_weight, topk_ids): + # TODO use moe_sum? + out = out[inv_perm, ...] + tmp_out = out.view(-1, topk, K) + return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + - a_q, a_s = per_token_group_quant_fp8(a, block_m) +def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, + block_shape): + """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" + num_groups = w1.shape[0] + M, K = a.shape - # Replicate activations and scales - a_q = a_q.view(a_q.shape[0], -1, - a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) - a_s = a_s.view(a_s.shape[0], -1, - a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) + topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + + _, block_k = block_shape[0], block_shape[1] - # Permute activations according to sorted token ids - a_q = fp8_perm(a_q, sorted_token_ids) - a_s = a_s[sorted_token_ids] + if False: + # quantize before permute + a_q, a_s = per_token_group_quant_fp8(a, block_m) + a_q, a_s, m_indices, inv_perm = test_moe_permute( + a_q, a_s, topk_ids, num_groups, topk, block_m) + else: + # quantize after permute + a_q, a_s, m_indices, inv_perm = test_moe_permute( + a, None, topk_ids, num_groups, topk, block_m) + a_q, a_s = per_token_group_quant_fp8(a_q, block_m) + + # Fix this assert + #assert a_s.shape[1] == K // 128 and a_q.shape[0] == a_s.shape[0] == M * topk inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, @@ -413,13 +446,22 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - out = out[inv_perm, ...] + if True: + final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, + num_groups, M, K, topk_weight, topk_ids) + else: + m_indices = torch.arange(0, + M * (topk + 1), + block_m, + dtype=torch.int, + device=out.device) - tmp_out = out[:(M * topk), ...].view(-1, topk, w2.shape[1]) + print(f"inv_perm {inv_perm}") + print(f"inv_perm[:m*topk] {inv_perm[:M*topk]}") - final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - # TODO use moe_sum? + final_out = test_moe_unpermute_op(out, inv_perm, m_indices, topk, + num_groups, M, K, topk_weight, + topk_ids) return final_out @@ -489,13 +531,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - if M % 128 == 0: - ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, - w2_s, score, topk, - block_size) - else: - ref_out2 = None - out = fused_moe(a, w1, w2, @@ -508,6 +543,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, allow_deep_gemm=True) + if M % 128 == 0: + out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, + w2_s, score, topk, + block_size) + else: + out2 = None + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -516,8 +558,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 - if ref_out2 is not None: + if out2 is not None: rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out2.to(torch.float32))) / - torch.mean(torch.abs(ref_out2.to(torch.float32)))) + torch.abs(ref_out.to(torch.float32) - out2.to(torch.float32))) / + torch.mean(torch.abs(out2.to(torch.float32)))) assert rel_diff < 0.03 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 60248b86a453..65c29e588373 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -903,10 +903,10 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + token_expert_indices = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. @@ -1331,6 +1331,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, # We attempt to do this offline in Fp8MoEMethod, in which case these # calls will be nops. Otherwise, they'll be performed every time the # layer is executed. + print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1454,19 +1455,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + # is this correct in the loop? TODO: fold in moe_sum? if use_dg and not skip_dg: - assert inv_perm is not None - M = curr_topk_weights.shape[0] - out_C = intermediate_cache3[inv_perm, ...] - out_C = out_C[:(M * top_k_num), ...] - out_C = out_C.view(-1, top_k_num, w2.shape[1]) - out_C.mul_(curr_topk_weights.view(M, -1, 1)) - tmp_cache3 = out_C + _moe_unpermute(out_hidden_states[begin_chunk_idx:end_chunk_idx], + intermediate_cache3, + inv_perm, + expert_ids, + top_k_num, + global_num_experts, + w2.shape[1], + curr_topk_weights, + curr_topk_ids) else: - tmp_cache3 = intermediate_cache3.view(*intermediate_cache3.shape) - - ops.moe_sum(tmp_cache3, - out_hidden_states[begin_chunk_idx:end_chunk_idx]) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states From 29add300add20b8f01f396614df61291fbfa6e89 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 22 Mar 2025 03:57:02 +0000 Subject: [PATCH 094/171] refactoring + perf tweaks Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 65c29e588373..372a0463929f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1319,6 +1319,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: + #print("USE_DG!") if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) @@ -1331,7 +1332,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, # We attempt to do this offline in Fp8MoEMethod, in which case these # calls will be nops. Otherwise, they'll be performed every time the # layer is executed. - print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") + #print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1355,6 +1356,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( new_M, w2.shape[1]) else: + #print(f"TRITON {allow_deep_gemm}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), From b1f5fcff7608633455e2839386b32335bc2aa358 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 24 Mar 2025 15:26:41 +0000 Subject: [PATCH 095/171] remove debugging cruft Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 372a0463929f..551992b37dcb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1319,7 +1319,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: - #print("USE_DG!") if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) @@ -1332,7 +1331,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, # We attempt to do this offline in Fp8MoEMethod, in which case these # calls will be nops. Otherwise, they'll be performed every time the # layer is executed. - #print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1356,7 +1354,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( new_M, w2.shape[1]) else: - #print(f"TRITON {allow_deep_gemm}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), From 2e196226615b8c2c9d40bc602beca6b1e147f229 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 24 Mar 2025 22:28:40 +0000 Subject: [PATCH 096/171] cache resize refactoring Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 15 ++--- .../layers/fused_moe/fused_moe.py | 58 +++++++------------ 2 files changed, 26 insertions(+), 47 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index f75f2f2f5f5f..188ea9723021 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -396,7 +396,6 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): def test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, M, K, topk_weight, topk_ids): - # TODO use moe_sum? out = out[inv_perm, ...] tmp_out = out.view(-1, topk, K) return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @@ -414,16 +413,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - if False: - # quantize before permute - a_q, a_s = per_token_group_quant_fp8(a, block_m) - a_q, a_s, m_indices, inv_perm = test_moe_permute( - a_q, a_s, topk_ids, num_groups, topk, block_m) - else: - # quantize after permute - a_q, a_s, m_indices, inv_perm = test_moe_permute( - a, None, topk_ids, num_groups, topk, block_m) - a_q, a_s = per_token_group_quant_fp8(a_q, block_m) + a_q, a_s = per_token_group_quant_fp8(a, block_m) + + a_q, a_s, m_indices, inv_perm = test_moe_permute( + a_q, a_s, topk_ids, num_groups, topk, block_m) # Fix this assert #assert a_s.shape[1] == K // 128 and a_q.shape[0] == a_s.shape[0] == M * topk diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 551992b37dcb..63e9bd5e5425 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -3,6 +3,7 @@ import functools import json import os +from math import prod from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -1317,14 +1318,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() - chunked_dg = False if use_dg: if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - chunked_dg = num_chunks > 1 - assert w1_scale is not None assert w2_scale is not None @@ -1334,41 +1331,30 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - # TODO: computing new_M could be smarter - sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, - global_num_experts, - expert_map)) - - new_M = ((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m + M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) + M_sum = ((M_sum + block_m - 1) // block_m) * block_m - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - cache13 = torch.empty(new_M * max(N, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = cache13[:(new_M * N)].view(new_M, N) - intermediate_cache2 = torch.empty((new_M, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( - new_M, w2.shape[1]) + cache1_view = (M_sum, N) + cache3_view = (M_sum, K) else: - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + M_sum = M * top_k_num + cache1_view = (M, top_k_num, N) + cache3_view = (M, top_k_num, K) + + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M_sum * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) - intermediate_cache1 = cache13[:M * top_k_num * N].view( - (M, topk_ids.shape[1], N)) - intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view( - (M, topk_ids.shape[1], w2.shape[1])) + intermediate_cache1 = cache13[:M_sum * N].view(*cache1_view) + intermediate_cache2 = torch.empty((M_sum, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) - num_chunks = (num_tokens // CHUNK_SIZE) + 1 for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, @@ -1462,7 +1448,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, expert_ids, top_k_num, global_num_experts, - w2.shape[1], + K, curr_topk_weights, curr_topk_ids) else: From 5d970220ce458aecfdb03dff19c26bb4035b3202 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 25 Mar 2025 14:23:17 +0000 Subject: [PATCH 097/171] cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 42 +++------------ .../layers/fused_moe/fused_moe.py | 51 ++++++++++--------- 2 files changed, 35 insertions(+), 58 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 188ea9723021..30ab50ddf798 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -229,7 +229,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): if topk > E: - pytest.skip(f"Skipping test; topk={K} > E={E}") + pytest.skip(f"Skipping test; topk={topk} > E={E}") torch.manual_seed(seed) factor_for_scale = 1e-2 @@ -351,7 +351,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def fp8_perm(m, idx): - if m.dtype == torch.float8_e4m3fn: + if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn) else: @@ -379,8 +379,6 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - #print(f"sti {sorted_token_ids}") - inv_perm = torch.argsort(sorted_token_ids)[:M*topk] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) @@ -418,9 +416,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s, m_indices, inv_perm = test_moe_permute( a_q, a_s, topk_ids, num_groups, topk, block_m) - # Fix this assert - #assert a_s.shape[1] == K // 128 and a_q.shape[0] == a_s.shape[0] == M * topk - inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) @@ -439,22 +434,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - if True: - final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, - num_groups, M, K, topk_weight, topk_ids) - else: - m_indices = torch.arange(0, - M * (topk + 1), - block_m, - dtype=torch.int, - device=out.device) - - print(f"inv_perm {inv_perm}") - print(f"inv_perm[:m*topk] {inv_perm[:M*topk]}") - - final_out = test_moe_unpermute_op(out, inv_perm, m_indices, topk, - num_groups, M, K, topk_weight, - topk_ids) + final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, + num_groups, M, K, topk_weight, topk_ids) return final_out @@ -502,6 +483,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] @@ -509,17 +493,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - # TODO: move size alignment further up when setting up all shapes - if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: - print("UNALIGNED") - pytest.skip("UNALIGNED") - - w1_s = w1_sa - w2_s = w2_sa - + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 63e9bd5e5425..133478f46af1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -524,8 +524,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if use_dg: assert use_fp8_w8a8 - # Note: we do not apply weights here since it requires - # resizing the output. + # Note: we never apply the topk_weights here since it requires + # unpermuting and resizing the output. This goes against the + # existing interface as the `mul_routed_weight` argument is + # ignored. The weights are applied in _moe_unpermute. dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (A, A_scale), (B, B_scale), C, expert_ids) @@ -855,7 +857,7 @@ def try_get_optimal_moe_config( config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - # Try to remove this + # Enforce DeepGemm M blocking no matter what the config says. if use_deep_gemm: config['BLOCK_SIZE_M'] = dg.get_m_alignment_for_contiguous_layout() @@ -904,10 +906,10 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indices = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. @@ -1319,15 +1321,18 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() if use_dg: + # If M is not divisible by the block size we run the largest + # chunk we can using DeepGemm, the remainder is handed off to + # the Triton kernels. if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) assert w1_scale is not None assert w2_scale is not None - # We attempt to do this offline in Fp8MoEMethod, in which case these - # calls will be nops. Otherwise, they'll be performed every time the - # layer is executed. + # We attempt to transpose and align offline in Fp8MoEMethod, in which + # case these calls will be nops. Otherwise, they'll be performed every + # time the layer is executed. w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1366,6 +1371,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break + # Even if we are using DeepGemm, we must defer any chunks + # that are not blocked to Triton. skip_dg = use_dg and tokens_in_chunk % block_m != 0 curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] @@ -1440,20 +1447,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - # is this correct in the loop? TODO: fold in moe_sum? - if use_dg and not skip_dg: - _moe_unpermute(out_hidden_states[begin_chunk_idx:end_chunk_idx], - intermediate_cache3, - inv_perm, - expert_ids, - top_k_num, - global_num_experts, - K, - curr_topk_weights, - curr_topk_ids) - else: - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), - out_hidden_states[begin_chunk_idx:end_chunk_idx]) + _moe_unpermute_and_reduce(out_hidden_states[begin_chunk_idx:end_chunk_idx], + intermediate_cache3.view(*intermediate_cache3.shape), + inv_perm, + expert_ids, + top_k_num, + global_num_experts, + K, + curr_topk_weights, + curr_topk_ids, + use_dg and not skip_dg) return out_hidden_states From 0c343cf79470ed819326706b3689c1f534673c9c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 25 Mar 2025 14:34:43 +0000 Subject: [PATCH 098/171] format Signed-off-by: Bill Nell --- benchmarks/kernels/benchmark_moe.py | 26 ++++++++-------- tests/kernels/test_block_fp8.py | 15 +++++----- .../layers/fused_moe/fused_moe.py | 30 ++++++++++++------- 3 files changed, 38 insertions(+), 33 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a3092c44e332..a274537a6751 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -30,20 +30,18 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config( - config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, - use_deep_gemm: bool = False -) -> float: +def benchmark_config(config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 30ab50ddf798..88e7e2bcdbaa 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - inv_perm = torch.argsort(sorted_token_ids)[:M*topk] + inv_perm = torch.argsort(sorted_token_ids)[:M * topk] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) a = fp8_perm(a, sorted_token_ids) @@ -413,8 +413,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s = per_token_group_quant_fp8(a, block_m) - a_q, a_s, m_indices, inv_perm = test_moe_permute( - a_q, a_s, topk_ids, num_groups, topk, block_m) + a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, + num_groups, topk, block_m) inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, @@ -434,8 +434,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, - num_groups, M, K, topk_weight, topk_ids) + final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, + M, K, topk_weight, topk_ids) return final_out @@ -511,9 +511,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, allow_deep_gemm=True) if M % 128 == 0: - out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, - w2_s, score, topk, - block_size) + out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) else: out2 = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 133478f46af1..24caaf15eef1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -511,6 +511,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k + if use_fp8_w8a8: + assert B_scale is not None + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. @@ -1360,7 +1374,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, dtype=hidden_states.dtype) intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1447,16 +1460,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - _moe_unpermute_and_reduce(out_hidden_states[begin_chunk_idx:end_chunk_idx], - intermediate_cache3.view(*intermediate_cache3.shape), - inv_perm, - expert_ids, - top_k_num, - global_num_experts, - K, - curr_topk_weights, - curr_topk_ids, - use_dg and not skip_dg) + _moe_unpermute_and_reduce( + out_hidden_states[begin_chunk_idx:end_chunk_idx], + intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, + expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, + curr_topk_ids, use_dg and not skip_dg) return out_hidden_states From f60b4b3fd216bfa603de0a7d1c89d71e45442e35 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 25 Mar 2025 16:36:14 +0000 Subject: [PATCH 099/171] revert test.txt, fix mypy errors Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 24caaf15eef1..36db8613cdd0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1334,6 +1334,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() + cache1_view: Tuple[int, ...] = () + cache2_view: Tuple[int, ...] = () + cache3_view: Tuple[int, ...] = () + if use_dg: # If M is not divisible by the block size we run the largest # chunk we can using DeepGemm, the remainder is handed off to From 856046b2e8cd9701c070352f1dad1b948ae0705b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 26 Mar 2025 22:15:33 +0000 Subject: [PATCH 100/171] review comments Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 36db8613cdd0..5212ea177f59 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -23,7 +23,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, round_up from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -1355,7 +1355,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = ((M_sum + block_m - 1) // block_m) * block_m + M_sum = round_up(M_sum, block_m) cache1_view = (M_sum, N) cache3_view = (M_sum, K) From c7f3ddb72639161d865e9fc0de8fef9f75ea7c03 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Mar 2025 02:21:17 +0000 Subject: [PATCH 101/171] review comments Signed-off-by: Bill Nell --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 2 -- vllm/model_executor/layers/quantization/fp8.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 751f1aaf2c38..721e36af2b28 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -272,8 +272,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts - # TODO: do we need to do deep gemm alignment here? - def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e184c3173a4c..5e0386ea6bde 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -443,7 +443,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.allow_deep_gemm = use_deep_gemm + self.allow_deep_gemm = allow_deep_gemm # Check for DeepGemm support. self.allow_deep_gemm = False From f653358f93a3e4ca7416c677b3e8074b07fc3407 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Mar 2025 03:14:23 +0000 Subject: [PATCH 102/171] clean up use_dg flags Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 20 ++++++------------- .../layers/fused_moe/fused_moe.py | 14 +++++++------ 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 88e7e2bcdbaa..fdb5f4c3a5bd 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -495,8 +495,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + if M % 128 == 0: + ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + else: + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) out = fused_moe(a, w1, @@ -510,12 +514,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, allow_deep_gemm=True) - if M % 128 == 0: - out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - else: - out2 = None - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -523,9 +521,3 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 - - if out2 is not None: - rel_diff = (torch.mean( - torch.abs(ref_out.to(torch.float32) - out2.to(torch.float32))) / - torch.mean(torch.abs(out2.to(torch.float32)))) - assert rel_diff < 0.03 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5212ea177f59..31733cce6eda 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1366,8 +1366,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 cache13 = torch.empty(M_sum * max(N, K), device=hidden_states.device, dtype=hidden_states.dtype) @@ -1378,6 +1378,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, dtype=hidden_states.dtype) intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) + needs_fp8_quantization = use_fp8_w8a8 or use_dg + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1388,9 +1390,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break - # Even if we are using DeepGemm, we must defer any chunks - # that are not blocked to Triton. - skip_dg = use_dg and tokens_in_chunk % block_m != 0 + # If we are using DeepGemm, only operate on chunks that are + # blocked, otherwise defer to Triton. + use_dg_for_chunk = use_dg and tokens_in_chunk % block_m == 0 curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1468,7 +1470,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, out_hidden_states[begin_chunk_idx:end_chunk_idx], intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, - curr_topk_ids, use_dg and not skip_dg) + curr_topk_ids, use_dg_for_chunk) return out_hidden_states From 9391c66193974621bd625b62b9e890846f6afdcd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Mar 2025 15:24:26 +0000 Subject: [PATCH 103/171] remove check for aligned M Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 31733cce6eda..0199347ab3e8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1339,12 +1339,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, cache3_view: Tuple[int, ...] = () if use_dg: - # If M is not divisible by the block size we run the largest - # chunk we can using DeepGemm, the remainder is handed off to - # the Triton kernels. - if M % block_m != 0: - CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) - assert w1_scale is not None assert w2_scale is not None @@ -1390,10 +1384,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break - # If we are using DeepGemm, only operate on chunks that are - # blocked, otherwise defer to Triton. - use_dg_for_chunk = use_dg and tokens_in_chunk % block_m == 0 - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1470,7 +1460,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, out_hidden_states[begin_chunk_idx:end_chunk_idx], intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, - curr_topk_ids, use_dg_for_chunk) + curr_topk_ids, use_dg) return out_hidden_states From 2351edfecfe7af5a31c7104082241e2e67cb61f4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Mar 2025 18:31:23 +0000 Subject: [PATCH 104/171] rebase + clean up test Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 35 ++++++++++++--------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index fdb5f4c3a5bd..ce80abb44997 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +from vllm.utils import round_up dg_available = False try: @@ -352,8 +353,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def fp8_perm(m, idx): if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: - return m.view(dtype=torch.uint8)[idx, - ...].view(dtype=torch.float8_e4m3fn) + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] @@ -366,34 +366,26 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): num_tokens = topk * M - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * - block_m) - sorted_token_ids.numel() + pad_size = (round_up(sorted_token_ids.numel(), block_m) - + sorted_token_ids.numel()) if pad_size > 0: sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - - assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - inv_perm = torch.argsort(sorted_token_ids)[:M * topk] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - a = fp8_perm(a, sorted_token_ids) - + a = fp8_perm(a, sorted_token_ids // topk) if a_s is not None: - a_s = a_s.view(M, -1, K // 128).repeat(1, topk, - 1).reshape(-1, K // 128) - a_s = a_s[sorted_token_ids] + a_s = a_s[sorted_token_ids // topk] return a, a_s, m_indices, inv_perm -def test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, M, K, - topk_weight, topk_ids): +def test_moe_unpermute(out, inv_perm, topk, K, topk_weight): + M = topk_weight.shape[0] out = out[inv_perm, ...] tmp_out = out.view(-1, topk, K) return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @@ -404,6 +396,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" num_groups = w1.shape[0] M, K = a.shape + N = w2.shape[-1] topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) @@ -416,7 +409,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, num_groups, topk, block_m) - inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), + inter_out = torch.zeros((a_q.shape[0], N * 2), dtype=torch.bfloat16, device=a.device) @@ -426,16 +419,14 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(act_out.shape[0], - w2.shape[1], + out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, - M, K, topk_weight, topk_ids) + final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight) return final_out @@ -495,7 +486,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): - if M % 128 == 0: + if M >= 128: ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_size) else: From d0e81cc8e8d19039de725d69140ef93f8ba824c3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Mar 2025 20:32:18 +0000 Subject: [PATCH 105/171] fix format Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ce80abb44997..89e9a073acf9 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform -from vllm.utils import round_up dg_available = False try: @@ -362,17 +361,10 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): M, K = a.shape sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, block_m, num_groups, None) + topk_ids, block_m, num_groups, None, pad_sorted_ids=True) num_tokens = topk * M - pad_size = (round_up(sorted_token_ids.numel(), block_m) - - sorted_token_ids.numel()) - if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, - (0, pad_size), "constant", - num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) inv_perm = torch.argsort(sorted_token_ids)[:M * topk] @@ -419,9 +411,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(a_q.shape[0], K, - dtype=torch.bfloat16, - device=a.device) + out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) @@ -490,8 +480,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_size) else: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) out = fused_moe(a, w1, From b5fb80c571c14b125cc571425e3a57ff86d638b1 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 31 Mar 2025 18:34:07 +0000 Subject: [PATCH 106/171] Clean up diff Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/cuda_graph_utils.py | 0 .../layers/fused_moe/fused_moe.py | 27 ++-------- vllm/model_executor/layers/fused_moe/layer.py | 50 ++++++++++++------- 3 files changed, 36 insertions(+), 41 deletions(-) delete mode 100644 vllm/cuda_graph_utils.py diff --git a/vllm/cuda_graph_utils.py b/vllm/cuda_graph_utils.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0199347ab3e8..3133675a2410 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -536,16 +536,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) - if use_dg: - assert use_fp8_w8a8 - # Note: we never apply the topk_weights here since it requires - # unpermuting and resizing the output. This goes against the - # existing interface as the `mul_routed_weight` argument is - # ignored. The weights are applied in _moe_unpermute. - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (A, A_scale), (B, B_scale), C, expert_ids) - - elif (use_int8_w8a16 or use_int4_w4a16) and \ + if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 @@ -847,7 +838,6 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, - use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -870,11 +860,6 @@ def try_get_optimal_moe_config( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - - # Enforce DeepGemm M blocking no matter what the config says. - if use_deep_gemm: - config['BLOCK_SIZE_M'] = dg.get_m_alignment_for_contiguous_layout() - return config @@ -1048,14 +1033,13 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> None: + block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape, allow_deep_gemm) + block_shape) def inplace_fused_experts_fake( @@ -1492,7 +1476,6 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = True, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1526,8 +1509,8 @@ def fused_moe( Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 279ba2778b1f..be9b2bc4a64d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,16 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod +from dataclasses import dataclass from enum import Enum from typing import Callable, List, Optional, Tuple -from dataclasses import dataclass +import pplx_kernels as pplx import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter -import pplx_kernels as pplx - import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, @@ -39,6 +38,7 @@ MOE_DP_CHUNK_SIZE = 256 + # Adapted from pplx-kernels tests/all_to_all_utils.py @dataclass class MoEConfig: @@ -56,6 +56,7 @@ class MoEConfig: out_dtype: torch.dtype = torch.bfloat16 block_size: int = 128 + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" @@ -92,9 +93,12 @@ def apply( ) -> torch.Tensor: raise NotImplementedError + +#TODO: Every change in this class is a broken hack!! @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self, moe: MoEConfig): self.all_to_all = pplx.AllToAll( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, @@ -108,7 +112,6 @@ def __init__(self, moe: MoEConfig): hidden_dim_scale_bytes=0, ) - def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -874,7 +877,7 @@ def forward(self, hidden_states: torch.Tensor, self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + full_router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp cu_tokens_across_dp_cpu = get_forward_context( @@ -890,21 +893,23 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_remaining_across_dp = num_tokens_across_dp chunk_start = 0 - chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + chunk_end = min(moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): - hidden_states = full_hidden_states[chunk_start:chunk_end,:] - router_logits = full_router_logits[chunk_start:chunk_end,:] + hidden_states = full_hidden_states[chunk_start:chunk_end, :] + router_logits = full_router_logits[chunk_start:chunk_end, :] cu_tokens_across_dp_this_iter = torch.cumsum( - num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), + num_tokens_remaining_across_dp.clamp( + max=moe_dp_chunk_size_per_rank), dim=0) - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_this_iter) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_dp_this_iter) + hidden_states = self.naive_multicast( + hidden_states, cu_tokens_across_dp_this_iter) + router_logits = self.naive_multicast( + router_logits, cu_tokens_across_dp_this_iter) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -925,7 +930,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, ) if self.dp_size > 1: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank-1] + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ + self.dp_rank - 1] end = cu_tokens_across_dp_this_iter[self.dp_rank] all_hidden_states = get_dp_group().all_reduce( @@ -934,20 +940,26 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) - full_final_hidden_states[chunk_start:chunk_end, :].copy_(final_hidden_states) + full_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states) # Update bounds - num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) + num_tokens_remaining_across_dp = torch.clamp( + num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, + min=0) + def update_chunk_bound(x: int): - return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + return min(x + moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) + chunk_start = update_chunk_bound(chunk_start) chunk_end = update_chunk_bound(chunk_end) return full_final_hidden_states - def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None From 204c4d59a57e8ba9500f6c57f81ea290b2ecbf09 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Tue, 1 Apr 2025 07:49:12 +0200 Subject: [PATCH 107/171] [Distributed] Add custom allreduce support for ROCM (#14125) Signed-off-by: ilmarkov Co-authored-by: ilmarkov Signed-off-by: Bill Nell --- csrc/custom_all_reduce.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 44709b459776..186abf4712fd 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -602,4 +602,4 @@ class CustomAllreduce { * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, half *, int, int, int); */ -} // namespace vllm \ No newline at end of file +} // namespace vllm From ad77c5f01407e4e8c50a3732bfe8608663e06425 Mon Sep 17 00:00:00 2001 From: Yan Ma Date: Tue, 1 Apr 2025 13:53:37 +0800 Subject: [PATCH 108/171] [Bugfix][Model] fix mllama multi-image (#14883) Signed-off-by: yan ma Signed-off-by: Bill Nell --- vllm/model_executor/models/mllama.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 0c1d61c01f91..971a4e695dab 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1245,6 +1245,31 @@ def unpack_data(self, output_tensor[i, :t.size(0)] = t return output_tensor + def unpack_data(self, + image_data: Union[List[torch.Tensor], torch.Tensor], + padding_value=0) -> torch.Tensor: + if isinstance(image_data, torch.Tensor): + # torch.Tensor + return image_data + else: + assert isinstance( + image_data[0], + torch.Tensor), "Image data is not properly batched." + # List[torch.Tensor] + bsz = len(image_data) + max_length = max(t.size(0) for t in image_data) + trailing_dims = image_data[0].shape[1:] + for data in image_data: + cur_trailing_dims = data.shape[1:] + assert cur_trailing_dims == trailing_dims + output_tensor = torch.full((bsz, max_length, *trailing_dims), + padding_value, + dtype=image_data[0].dtype, + device=image_data[0].device) + for i, t in enumerate(image_data): + output_tensor[i, :t.size(0)] = t + return output_tensor + def _parse_and_validate_image_input(self, **kwargs: object): # tensor with the same shape will be batched together by # MultiModalKwargs.batch, so pixel_values here can be: From 84782a177b4e0709eb22885f2f0540e771499901 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Apr 2025 22:06:39 +0000 Subject: [PATCH 109/171] module deepgemm moe working Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 39 +++++++++++-------- .../layers/fused_moe/modular_kernel.py | 1 + 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 89e9a073acf9..e747a96abf13 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -9,8 +9,12 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + deep_gemm_moe_fp8, + modular_deep_gemm_fused_moe_fp8) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -430,11 +434,13 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes - if (N % 128 != 0 or K % 128 != 0 or topk > E or block_size != [128, 128]): + # only aligned sizes TODO: use _valid_deep_gemm here instead? + if (N % block_m != 0 or K % block_m != 0 or topk > E): pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}, " - f"block_size={block_size}") + f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + + if False and N <= 512: + pytest.skip("Skipping N <= 512 until performance issues solved.") vllm_config = VllmConfig() @@ -474,6 +480,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + if True: + dgm = modular_deep_gemm_fused_moe_fp8() + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): + return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -483,17 +496,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - out = fused_moe(a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True) + topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) + + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index aab7658ae641..c386d5ec1dcd 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -126,6 +126,7 @@ def combine( experts, it will have (M, topk, K) shape. - topk_weights: The weights to be applied to the fused_experts_output. - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to fused_expert_output. """ raise NotImplementedError From d88baaa6bd448330ccc50469310efc491e8a1f73 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 01:13:38 +0000 Subject: [PATCH 110/171] working deep gemm, wip cutlass Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index e747a96abf13..2511d817fe72 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -456,6 +457,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) +# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): +# pytest.skip( +# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] From bf9a8337d527146b44653b7d77737f289bc2b244 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 13:49:41 +0000 Subject: [PATCH 111/171] working cutlass Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 266ba3bfa07a..38f8072ac408 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch From ab7ff872cbbdb5d9b5a7d48fdb55eb876c139342 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:33:59 +0000 Subject: [PATCH 112/171] deepgemm working again Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 38f8072ac408..4fa7139afa41 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch From b1f59a888ce0b3e51ac13e9b6d408fde57caae5e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 22:14:54 +0000 Subject: [PATCH 113/171] fix inplace, format and name cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 21 +++++++++++++------ .../layers/fused_moe/deep_gemm_moe.py | 2 +- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2511d817fe72..fafc7c18254e 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -457,9 +457,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) -# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): -# pytest.skip( -# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): + # pytest.skip( + # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") score = torch.randn((M, E), dtype=dtype) @@ -487,8 +487,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if True: dgm = modular_deep_gemm_fused_moe_fp8() - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): - return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids): + return dgm(a, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 @@ -503,7 +511,8 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 4fa7139afa41..28050a5dd9e6 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple import torch From b9542bcadd4caef6c8248d4c0d51ac693a65dd95 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 01:18:53 +0000 Subject: [PATCH 114/171] test improvements Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++++------------ .../layers/fused_moe/deep_gemm_moe.py | 4 ++-- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index fafc7c18254e..ed861054b4b8 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,9 +10,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8) + modular_deep_gemm_fused_moe_fp8, + _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -435,13 +435,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes TODO: use _valid_deep_gemm here instead? - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - if False and N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") vllm_config = VllmConfig() @@ -457,10 +455,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): - # pytest.skip( - # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 28050a5dd9e6..facbba40c3e5 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch @@ -28,7 +28,7 @@ def deep_gemm_block_shape() -> list[int]: def _valid_deep_gemm_shape(M: int, N: int, K: int): align = deep_gemm_block_shape()[0] - return align <= M and N % align == 0 and K % align == 0 + return M >= align and N % align == 0 and K % align == 0 def _valid_deep_gemm(hidden_states: torch.Tensor, From e974b59f58cc40d48555b3492378cb0c5c3a5818 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 04:41:01 +0000 Subject: [PATCH 115/171] make modular triton classes, fix edge cases Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3133675a2410..af14607122fc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1207,6 +1207,30 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) + elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: + fe = modular_triton_fused_moe( + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_shape, + ) + return fe( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1214,6 +1238,7 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, @@ -1440,11 +1465,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - _moe_unpermute_and_reduce( - out_hidden_states[begin_chunk_idx:end_chunk_idx], - intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, - expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, - curr_topk_ids, use_dg) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states From 1a7bdbd4805b54e9501af92633f3e3e678f4fbd1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 17:17:16 +0000 Subject: [PATCH 116/171] refactor dispatch/combine stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 3 +++ vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e52751eddf2c..19ca505a2561 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,6 +9,9 @@ from vllm.model_executor.layers.fused_moe.dispatch_combine import ( StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index facbba40c3e5..ab355c7d53e1 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -12,6 +12,13 @@ _moe_permute) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute, + _moe_unpermute_and_reduce +) +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) from vllm.utils import round_up logger = init_logger(__name__) From ca50521212b34f1e40876c7b327c75a38ed71cad Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 19:39:35 +0000 Subject: [PATCH 117/171] initial pplx dispatch/combine class Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ed861054b4b8..2f9315f19529 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -362,7 +362,7 @@ def fp8_perm(m, idx): return m[idx, ...] -def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): +def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): M, K = a.shape sorted_token_ids, m_indices, num_pad = moe_align_block_size( @@ -381,7 +381,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): return a, a_s, m_indices, inv_perm -def test_moe_unpermute(out, inv_perm, topk, K, topk_weight): +def _moe_unpermute(out, inv_perm, topk, K, topk_weight): M = topk_weight.shape[0] out = out[inv_perm, ...] tmp_out = out.view(-1, topk, K) @@ -403,8 +403,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s = per_token_group_quant_fp8(a, block_m) - a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, - num_groups, topk, block_m) + a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, + num_groups, topk, block_m) inter_out = torch.zeros((a_q.shape[0], N * 2), dtype=torch.bfloat16, @@ -421,7 +421,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight) + final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) return final_out From a5c89072e415149efb554e1025595b0959baaaac Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:41:40 +0000 Subject: [PATCH 118/171] merge triton dispatch into standard, add some comments Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 1 + .../model_executor/layers/fused_moe/pplx_dispatch_combine.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c386d5ec1dcd..9cc8131a5d81 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -102,6 +102,7 @@ def dispatch( - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. Returns a tuple of: - quantized + dispatched a. diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 658705515b43..8eac4fd3f5e7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -47,8 +47,6 @@ def dispatch( assert expert_map is None, "NYI" - # TBD - assert not apply_router_weight_on_input if apply_router_weight_on_input: topk = rank_topk_ids.shape[1] # TODO: this only works for topK=1, will need to update for topK>1 @@ -131,8 +129,7 @@ def combine( assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] - # Set weights to 1? - assert not apply_router_weight_on_input + # Set weights to 1 if we did them in dispatch. This is hacky. if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) From 939ef2fc25dc56fc8d64705e7b4e0d40232a0ce5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:47:41 +0000 Subject: [PATCH 119/171] format Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 +--- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 3 --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 9 +-------- vllm/model_executor/layers/fused_moe/modular_kernel.py | 1 + 4 files changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2f9315f19529..3939f4b7bab1 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,9 +10,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8, - _valid_deep_gemm_shape) + _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 19ca505a2561..e52751eddf2c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,9 +9,6 @@ from vllm.model_executor.layers.fused_moe.dispatch_combine import ( StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache -from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index ab355c7d53e1..266ba3bfa07a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -12,13 +12,6 @@ _moe_permute) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_permute, - _moe_unpermute_and_reduce -) -from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) from vllm.utils import round_up logger = init_logger(__name__) @@ -35,7 +28,7 @@ def deep_gemm_block_shape() -> list[int]: def _valid_deep_gemm_shape(M: int, N: int, K: int): align = deep_gemm_block_shape()[0] - return M >= align and N % align == 0 and K % align == 0 + return align <= M and N % align == 0 and K % align == 0 def _valid_deep_gemm(hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 9cc8131a5d81..a3086dee4b30 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -72,6 +72,7 @@ def _moe_problem_size( return E, M, N, K, topk + class FusedMoEQuantizeDispatchCombine(ABC): """ An abstract base class for the [Quantize-Dispatch] and [Combine] steps From 65f4b552a8ef7191c12204834e4cb413b2c19402 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 23:20:14 +0000 Subject: [PATCH 120/171] cleanup for review Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 20 ++-------------- .../layers/fused_moe/fused_moe.py | 24 ------------------- 2 files changed, 2 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3939f4b7bab1..762d02394086 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -477,21 +477,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - if True: - dgm = modular_deep_gemm_fused_moe_fp8() - - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids): - return dgm(a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s) - else: - deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -503,8 +488,7 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index af14607122fc..941cb753e1c4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1207,30 +1207,6 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: - fe = modular_triton_fused_moe( - use_fp8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - block_shape, - ) - return fe( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1_scale, - a2_scale, - ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, From 2672f689a3673d0a81c1529e0b015eaae7f65909 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 23:17:43 +0000 Subject: [PATCH 121/171] hacking Signed-off-by: Bill Nell --- .../layers/fused_moe/__init__.py | 5 +-- vllm/model_executor/layers/fused_moe/layer.py | 34 +++++++++++++++++-- .../layers/fused_moe/pplx_dispatch_combine.py | 6 ++-- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9829ccdb384f..b55a3ede12d3 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -38,8 +38,8 @@ def get_config() -> Optional[Dict[str, Any]]: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + TritonExperts, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ "fused_moe", @@ -48,4 +48,5 @@ def get_config() -> Optional[Dict[str, Any]]: "get_config_file_name", "grouped_topk", "cutlass_moe_fp8", + "TritonExperts", ] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index be9b2bc4a64d..9233c46ba0c8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,6 +8,8 @@ import pplx_kernels as pplx import torch import torch.nn.functional as F +from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, nvshmem_init) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -23,10 +25,13 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, run_once if current_platform.is_cuda_alike(): - from .fused_moe import fused_experts + #from .pplx_dispatch_combine import PplxDispatchCombine + from .dispatch_combine import StandardDispatchCombine + from .fused_moe import TritonExperts, fused_experts + from .modular_kernel import FusedMoEModularKernel else: fused_experts = None # type: ignore if current_platform.is_tpu(): @@ -405,6 +410,14 @@ def determine_expert_map( return (local_num_experts, expert_map) +@run_once +def pplx_init(rank, world_size): + uid = nvshmem_get_unique_id( + ) if rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, rank, world_size) + + class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -528,8 +541,23 @@ def __init__( # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. if quant_config is None: + pplx_init(self.dp_rank, self.dp_size) + + moe = MoEConfig( + num_experts=self.global_num_experts, + experts_per_token=0, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + dp_size=self.dp_size, + dp_rank=self.dp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + #in_dtype = 0, + #out_dtype = 0, + ) + self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + UnquantizedFusedMoEMethod(moe)) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 8eac4fd3f5e7..0302524fe1c2 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -69,14 +69,14 @@ def dispatch( expert_num_tokens = torch.empty( num_local_experts, dtype=torch.int32, - device=device, + device=a1.device, ) num_dp = self.world_size // self.dp_size expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, - device=device, + device=a1.device, ) expert_x_scale: Optional[torch.Tensor] = None @@ -91,7 +91,7 @@ def dispatch( (expert_x.size(2) + block_size - 1) // block_size, ), dtype=torch.float32, - device=device, + device=a1.device, ) # This argument is optional, defaults to indices.shape[0] From a6df5b728dec72f0a0a7ded69cd78bc81e31f81f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 15:04:28 +0000 Subject: [PATCH 122/171] hacking Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 62 ++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9233c46ba0c8..cb9ff7446827 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import threading +import weakref from abc import abstractmethod from dataclasses import dataclass from enum import Enum @@ -99,13 +101,67 @@ def apply( raise NotImplementedError +class AllToAllCache: + + def __init__(self): + self._cache = {} + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, **kwargs): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + if key in self._cache: + instance, refs = self._cache[key] + new_ref = weakref.ref(object(), + lambda _: self._decrement_ref_count(key)) + refs.append(new_ref) + return instance + else: + # Create new instance + instance = pplx.AllToAll(**kwargs) + # Use a weakref.ref with a callback when reference is collected + refs = [ + weakref.ref(object(), + lambda _: self._decrement_ref_count(key)) + ] + self._cache[key] = (instance, refs) + return instance + + def _decrement_ref_count(self, key): + with self._lock: + if key in self._cache: + instance, refs = self._cache[key] + # Remove dead references + refs = [ref for ref in refs if ref() is not None] + if not refs: + # No more references, clean up the instance + instance.destroy() + del self._cache[key] + else: + # Update refs + self._cache[key] = (instance, refs) + + +# Global singleton +_all_to_all_cache = AllToAllCache() + + +# Factory function as a cleaner interface +def get_all_to_all(**kwargs): + return _all_to_all_cache.get_or_create(**kwargs) + + #TODO: Every change in this class is a broken hack!! @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: MoEConfig): - self.all_to_all = pplx.AllToAll( + pplx_init(moe.ep_rank, moe.ep_size) + + self.all_to_all = get_all_to_all( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, num_experts=moe.num_experts, experts_per_token=moe.experts_per_token, @@ -412,9 +468,11 @@ def determine_expert_map( @run_once def pplx_init(rank, world_size): + print(f"PPLX_INIT {rank} {world_size}") uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) + print(f"PPLX_INIT UID={uid}") + torch.distributed.broadcast(uid.cuda(), src=0) nvshmem_init(uid, rank, world_size) From bddffe7adb0068d1ba52dfcec9940e63bb84fb86 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 19:41:44 +0000 Subject: [PATCH 123/171] init stuff Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 18 +++++++- vllm/model_executor/layers/fused_moe/layer.py | 44 ++++++++++--------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cb9658ce1004..0bb4835939af 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -42,7 +42,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, - supports_custom_op) + run_once, supports_custom_op) @dataclass @@ -912,6 +912,20 @@ def init_distributed_environment( "world group already initialized with a different world size") +@run_once +def pplx_init(rank, world_size): + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, nvshmem_init) + print(f"PPLX_INIT {rank} {world_size}") + uid = nvshmem_get_unique_id( + ) if rank == 0 else nvshmem_alloc_empty_unique_id() + uid_gpu = uid.cuda() + get_world_group().broadcast(uid_gpu, src=0) + print(f"PPLX_INIT UID={uid_gpu}") + uid = uid_gpu.to(device='cpu') + nvshmem_init(uid, rank, world_size) + + def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, @@ -1006,6 +1020,8 @@ def initialize_model_parallel( "DP rank %s, PP rank %s, TP rank %s", rank, world_size, _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + pplx_init(rank, world_size) + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cb9ff7446827..06e29945cac6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -10,8 +10,6 @@ import pplx_kernels as pplx import torch import torch.nn.functional as F -from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -27,13 +25,13 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op, run_once +from vllm.utils import direct_register_custom_op if current_platform.is_cuda_alike(): - #from .pplx_dispatch_combine import PplxDispatchCombine from .dispatch_combine import StandardDispatchCombine from .fused_moe import TritonExperts, fused_experts from .modular_kernel import FusedMoEModularKernel + from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore if current_platform.is_tpu(): @@ -101,7 +99,7 @@ def apply( raise NotImplementedError -class AllToAllCache: +class AllToAllCacheThreadSafe: def __init__(self): self._cache = {} @@ -120,6 +118,7 @@ def get_or_create(self, **kwargs): return instance else: # Create new instance + print("CREATE AllToAll") instance = pplx.AllToAll(**kwargs) # Use a weakref.ref with a callback when reference is collected refs = [ @@ -144,6 +143,25 @@ def _decrement_ref_count(self, key): self._cache[key] = (instance, refs) +class AllToAllCache: + + def __init__(self): + self._cache = {} + + def get_or_create(self, **kwargs): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + if key in self._cache: + return self._cache[key] + else: + # Create new instance + print("CREATE AllToAll") + instance = pplx.AllToAll(**kwargs) + self._cache[key] = instance + return instance + + # Global singleton _all_to_all_cache = AllToAllCache() @@ -159,8 +177,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: MoEConfig): - pplx_init(moe.ep_rank, moe.ep_size) - self.all_to_all = get_all_to_all( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, num_experts=moe.num_experts, @@ -301,7 +317,7 @@ def forward_cuda( e_score_correction_bias=e_score_correction_bias) return fused_experts( - hidden_states=x, + a1=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, @@ -466,16 +482,6 @@ def determine_expert_map( return (local_num_experts, expert_map) -@run_once -def pplx_init(rank, world_size): - print(f"PPLX_INIT {rank} {world_size}") - uid = nvshmem_get_unique_id( - ) if rank == 0 else nvshmem_alloc_empty_unique_id() - print(f"PPLX_INIT UID={uid}") - torch.distributed.broadcast(uid.cuda(), src=0) - nvshmem_init(uid, rank, world_size) - - class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -599,8 +605,6 @@ def __init__( # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. if quant_config is None: - pplx_init(self.dp_rank, self.dp_size) - moe = MoEConfig( num_experts=self.global_num_experts, experts_per_token=0, From 1813ae488a307374cb6a1cbaab4fbf4ca7132e2f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 19:52:03 +0000 Subject: [PATCH 124/171] call super ctor + fix random stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 06e29945cac6..564d033ead77 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -177,7 +177,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: MoEConfig): - self.all_to_all = get_all_to_all( + super().__init__() + self._moe = moe + self._all_to_all = get_all_to_all( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, num_experts=moe.num_experts, experts_per_token=moe.experts_per_token, From d50afb6cf82ad85810ad1f3e175cc0a0815e742f Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 15:53:52 -0400 Subject: [PATCH 125/171] fix use_ep bug Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 564d033ead77..f9fea86c512e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -548,7 +548,7 @@ def __init__( # Use expert parallelism instead of tensor parallelism? vllm_config = get_current_vllm_config() use_ep = (vllm_config.parallel_config.enable_expert_parallel - and self.tp_size * self.dp_size > 1) + and (self.tp_size * self.dp_size) > 1) # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 From 207a373886e0eb9bc2004668a740a3735436df66 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 16:30:28 -0400 Subject: [PATCH 126/171] Fix dp_size Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f9fea86c512e..e8435214faff 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -185,7 +185,7 @@ def __init__(self, moe: MoEConfig): experts_per_token=moe.experts_per_token, rank=moe.ep_rank, world_size=moe.ep_size, - dp_size=moe.dp_size, + dp_size=moe.ep_size // moe.dp_size, hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, hidden_dim_scale_bytes=0, From ea821e3c6512282e2d6ca43cd329e21683d0c725 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 16:39:36 -0400 Subject: [PATCH 127/171] add comment Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e8435214faff..b06fdb274709 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -185,7 +185,7 @@ def __init__(self, moe: MoEConfig): experts_per_token=moe.experts_per_token, rank=moe.ep_rank, world_size=moe.ep_size, - dp_size=moe.ep_size // moe.dp_size, + dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, hidden_dim_scale_bytes=0, From e4acd18beeebff846d40f366efa51515e417c97c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 16:48:50 -0400 Subject: [PATCH 128/171] fixes Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 0302524fe1c2..86fa17561f20 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -103,6 +103,9 @@ def dispatch( # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) + # TODO: optimize this? + rank_topk_ids = rank_topk_ids.to(dtype=torch.uint32) + self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, From 353151ed8cc59ebcf87da7c09a49f859f28a1d56 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 20:46:49 +0000 Subject: [PATCH 129/171] get a bit further Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 8 ++++++-- vllm/model_executor/layers/fused_moe/layer.py | 5 +++++ .../layers/fused_moe/pplx_dispatch_combine.py | 3 --- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0bb4835939af..088dc49bf3fd 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -35,6 +35,9 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, nvshmem_init, + nvshmem_finalize) import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( @@ -914,8 +917,6 @@ def init_distributed_environment( @run_once def pplx_init(rank, world_size): - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init) print(f"PPLX_INIT {rank} {world_size}") uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() @@ -1097,6 +1098,8 @@ def get_tensor_model_parallel_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP + nvshmem_finalize() + if _TP: _TP.destroy() _TP = None @@ -1112,6 +1115,7 @@ def destroy_model_parallel(): _DP = None + def destroy_distributed_environment(): global _WORLD if _WORLD: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b06fdb274709..6d974333aba9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -161,6 +161,11 @@ def get_or_create(self, **kwargs): self._cache[key] = instance return instance + def clear(): + for k, v in self._cache.items(): + v.destroy() + del self._cache + # Global singleton _all_to_all_cache = AllToAllCache() diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 86fa17561f20..0302524fe1c2 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -103,9 +103,6 @@ def dispatch( # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) - # TODO: optimize this? - rank_topk_ids = rank_topk_ids.to(dtype=torch.uint32) - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, From 70fc2a8db2117dc285da801b912c55383c613be8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 9 Apr 2025 23:06:46 +0000 Subject: [PATCH 130/171] hacking in dispatch_combine Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 10 +- vllm/model_executor/layers/fused_moe/layer.py | 122 +++++++++++++----- .../layers/fused_moe/modular_kernel.py | 18 ++- .../layers/fused_moe/pplx_dispatch_combine.py | 6 + .../layers/fused_moe/triton_deep_gemm_moe.py | 104 +++++++++++++++ .../model_executor/layers/quantization/fp8.py | 39 +++++- 6 files changed, 252 insertions(+), 47 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 941cb753e1c4..6cde135570ed 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1666,6 +1666,9 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") + #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") + #print(f"BLOCK_M = {self.block_m}") + # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, @@ -1676,8 +1679,11 @@ def apply( (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + moe_align_block_size( + topk_ids, + config['BLOCK_SIZE_M'] if self.block_m is None else self.block_m, + global_num_experts, expert_map + )) invoke_fused_moe_kernel(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6d974333aba9..a8616af3aa72 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -30,7 +30,7 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine from .fused_moe import TritonExperts, fused_experts - from .modular_kernel import FusedMoEModularKernel + from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore @@ -77,6 +77,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError + def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + return False + @abstractmethod def apply( self, @@ -118,7 +121,6 @@ def get_or_create(self, **kwargs): return instance else: # Create new instance - print("CREATE AllToAll") instance = pplx.AllToAll(**kwargs) # Use a weakref.ref with a callback when reference is collected refs = [ @@ -183,18 +185,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: MoEConfig): super().__init__() - self._moe = moe - self._all_to_all = get_all_to_all( - max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, - rank=moe.ep_rank, - world_size=moe.ep_size, - dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - hidden_dim_scale_bytes=0, - ) + self.fused_experts = fused_experts + self.moe = moe def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -293,6 +285,26 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) + # Maybe extra args + def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) + print(f"block_m = {block_m}") + + experts = TritonExperts( + use_fp8_w8a8 = False, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = None, + block_m = None, #block_m, + ) + + self.fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + return True + def forward_cuda( self, layer: torch.nn.Module, @@ -323,8 +335,8 @@ def forward_cuda( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts( - a1=x, + return self.fused_experts( + hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, @@ -333,7 +345,8 @@ def forward_cuda( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + ) def forward_cpu( self, @@ -609,27 +622,67 @@ def __init__( from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + moe = MoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, # ? must be same as topk_ids.shape[1] + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + dp_size=self.dp_size, + dp_rank=self.dp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + #in_dtype = 0, + #out_dtype = 0, + ) + # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. + quant_method: Optional[FusedMoEMethodBase] = None + if quant_config is None: - moe = MoEConfig( - num_experts=self.global_num_experts, - experts_per_token=0, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - dp_size=self.dp_size, - dp_rank=self.dp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - #in_dtype = 0, - #out_dtype = 0, + quant_method = UnquantizedFusedMoEMethod(moe) + else: + # moe? + # TODO: setup dispatcher on FusedMoE. callees of this + # function can grab dispatcher from there? Or add + # supports_dispatcher/set_dispatcher method on FusedMoeMethodBase + quant_method = quant_config.get_quant_method(self, prefix) + assert isinstance(quant_method, FusedMoEMethodBase) + + assert quant_method is not None + self.quant_method = quant_method + + # TODO: move to method? + if self.dp_size > 1: + all_to_all = get_all_to_all( + max_num_tokens=MOE_DP_CHUNK_SIZE, # // moe.dp_size, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # has to be same as topk_ids.shape[1] + rank=moe.ep_rank, + world_size=moe.ep_size, + dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + hidden_dim_scale_bytes=0, ) - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod(moe)) - else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None + if False: + dispatch_combine = PplxDispatchCombine( + all_to_all, + MOE_DP_CHUNK_SIZE, + moe.ep_size, + moe.dp_size, + moe.in_dtype, + ) + else: + dispatch_combine = StandardDispatchCombine( + moe.in_dtype, + quant_config.weight_block_size if quant_config is not None else None, + ) + + success = self.quant_method.set_dispatch_combine(dispatch_combine) + if not success: + logger.warning("DP+EP not supported for %s.", type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -975,6 +1028,7 @@ def forward(self, hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): + max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp cu_tokens_across_dp_cpu = get_forward_context( @@ -982,6 +1036,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp + #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a3086dee4b30..f7b3f7899dd1 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -60,15 +60,19 @@ def _moe_problem_size( E, N, _ = w1.shape K = w2.shape[1] - assert a1.dim() == 2 assert topk_ids.dim() == 2 - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.shape[0] == a1.shape[ - 0], f"{topk_ids.shape[0]} != {a1.shape[0]}" - - M = a1.shape[0] topk = topk_ids.shape[1] + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.shape[0] == a1.shape[0], \ + f"{topk_ids.shape[0]} != {a1.shape[0]}" + M = a1.shape[0] + else: + assert a1.dim() == 3 + assert E == a1.shape[0] + M = a1.shape[1] # This is max_num_tokens + return E, M, N, K, topk @@ -311,6 +315,8 @@ def forward( a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) + #print(f"INIT shape: E={E}, M={M}, N={N}, K={K}, top_k={top_k}") + if global_num_experts == -1: global_num_experts = E diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 0302524fe1c2..fa717c40c774 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -30,6 +30,10 @@ def __init__(self, self.dp_size = dp_size self.rank = rank self.quant_dtype = quant_dtype + print(f"max_num_tokens = {max_num_tokens}") + print(f"dp_num_tokens = {self.dp_num_tokens}") + print(f"world_size = {world_size}") + print(f"dp_size = {dp_size}") def dispatch( self, @@ -71,6 +75,7 @@ def dispatch( dtype=torch.int32, device=a1.device, ) + expert_num_tokens.fill_(-1) num_dp = self.world_size // self.dp_size expert_x = torch.empty( @@ -78,6 +83,7 @@ def dispatch( dtype=a1q.dtype, device=a1.device, ) + expert_x.fill_(torch.nan) expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py new file mode 100644 index 000000000000..f3a13e44296d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +import importlib.util +from typing import List, Optional, Tuple + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts, + _valid_deep_gemm_shape, + _valid_deep_gemm, +) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExpert + +class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + allow_deep_gemm: bool = False + ): + super().__init__() + self.triton_expert = TritonExpert( + use_fp8_w8a8, + use_int4_w4a16, + use_int8_w8a16, + block_shape, + block_m + ) + self.deep_gemm_expert = DeepGemmExperts() + self.allow_deep_gemm = allow_deep_gemm + self.use_fp8_w8a8 = use_fp8_w8a8 + + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: + # Note: the deep gemm workspaces are strictly larger than the triton + # workspaces so we can be pessimistic here and allocate for DeepGemm + # even if we fall back to triton later, e.g. if expert maps are set. + if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + else: + return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + N = w1.shape[1] + if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + return self.deep_gemm_expert( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + ) + else: + return self.triton_expert( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5e0386ea6bde..70feeb1167f5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util from typing import Any, Callable, Dict, List, Optional @@ -10,6 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, @@ -441,6 +443,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): """ def __init__(self, quant_config: Fp8Config): + from vllm.model_executor.layers.fused_moe import fused_experts self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None self.allow_deep_gemm = allow_deep_gemm @@ -458,6 +461,11 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") + self.fused_experts = functools.partial( + fused_experts, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm) + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -769,6 +777,29 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) return + # Maybe extra args + def set_dispatch_combine(self, dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import TritonOrDeepGemmExperts + + #block_m = MOE_DP_CHUNK_SIZE * (moe.ep_size // moe.dp_size) + #print(f"block_m = {block_m}") + + experts = TritonOrDeepGemmExperts( + use_fp8_w8a8 = True, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = self.quant_config.weight_block_size, + block_m = None, # TODO + allow_deep_gemm=self.allow_deep_gemm, + ) + + self.fused_experts = mk.FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + return True + def apply( self, layer: torch.nn.Module, @@ -787,8 +818,6 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -802,8 +831,8 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) - return fused_experts( - x, + return self.fused_experts( + hidden_states=x, layer.w13_weight, layer.w2_weight, topk_weights=topk_weights, @@ -820,8 +849,6 @@ def apply( if self.block_quant else layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, ) From 3b319a12e0566ed11776d101c2c1bfc16f1ade8d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 10 Apr 2025 14:47:37 +0000 Subject: [PATCH 131/171] hook up some wires Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 130 +++++++----------- .../layers/fused_moe/pplx_dispatch_combine.py | 2 + 2 files changed, 54 insertions(+), 78 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a8616af3aa72..c114dd79cde2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -57,8 +57,10 @@ class MoEConfig: ep_size: int ep_rank: int - in_dtype: torch.dtype = torch.bfloat16 - out_dtype: torch.dtype = torch.bfloat16 + in_dtype: torch.dtype + out_dtype: torch.dtype + + # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 @@ -102,10 +104,10 @@ def apply( raise NotImplementedError -class AllToAllCacheThreadSafe: +class AllToAllCache: def __init__(self): - self._cache = {} + self._cache = weakref.WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety def get_or_create(self, **kwargs): @@ -113,61 +115,12 @@ def get_or_create(self, **kwargs): key = tuple(sorted((k, v) for k, v in kwargs.items())) with self._lock: - if key in self._cache: - instance, refs = self._cache[key] - new_ref = weakref.ref(object(), - lambda _: self._decrement_ref_count(key)) - refs.append(new_ref) - return instance - else: - # Create new instance + instance = self._cache.get(key) + if instance is None: instance = pplx.AllToAll(**kwargs) - # Use a weakref.ref with a callback when reference is collected - refs = [ - weakref.ref(object(), - lambda _: self._decrement_ref_count(key)) - ] - self._cache[key] = (instance, refs) - return instance - - def _decrement_ref_count(self, key): - with self._lock: - if key in self._cache: - instance, refs = self._cache[key] - # Remove dead references - refs = [ref for ref in refs if ref() is not None] - if not refs: - # No more references, clean up the instance - instance.destroy() - del self._cache[key] - else: - # Update refs - self._cache[key] = (instance, refs) - - -class AllToAllCache: - - def __init__(self): - self._cache = {} - - def get_or_create(self, **kwargs): - # Create a hashable key from the kwargs - key = tuple(sorted((k, v) for k, v in kwargs.items())) - - if key in self._cache: - return self._cache[key] - else: - # Create new instance - print("CREATE AllToAll") - instance = pplx.AllToAll(**kwargs) - self._cache[key] = instance + self._cache[key] = instance return instance - def clear(): - for k, v in self._cache.items(): - v.destroy() - del self._cache - # Global singleton _all_to_all_cache = AllToAllCache() @@ -622,6 +575,8 @@ def __init__( from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + print(f"params dtype= {params_dtype}") + moe = MoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, # ? must be same as topk_ids.shape[1] @@ -631,8 +586,8 @@ def __init__( dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - #in_dtype = 0, - #out_dtype = 0, + in_dtype = params_dtype, # this is probably not right, where to get? + out_dtype = params_dtype, # ditto. ) # Note: get_quant_method will look at the layer's local_num_experts @@ -642,10 +597,6 @@ def __init__( if quant_config is None: quant_method = UnquantizedFusedMoEMethod(moe) else: - # moe? - # TODO: setup dispatcher on FusedMoE. callees of this - # function can grab dispatcher from there? Or add - # supports_dispatcher/set_dispatcher method on FusedMoeMethodBase quant_method = quant_config.get_quant_method(self, prefix) assert isinstance(quant_method, FusedMoEMethodBase) @@ -654,24 +605,47 @@ def __init__( # TODO: move to method? if self.dp_size > 1: - all_to_all = get_all_to_all( - max_num_tokens=MOE_DP_CHUNK_SIZE, # // moe.dp_size, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # has to be same as topk_ids.shape[1] - rank=moe.ep_rank, - world_size=moe.ep_size, - dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - hidden_dim_scale_bytes=0, - ) + if True: + max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + world_size = moe.ep_size + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + rank = moe.ep_rank + + print(f"max num = {max_num_tokens}") + print(f"world size = {world_size}") + print(f"moe ep size = {moe.ep_size}") + print(f"moe dp size = {moe.dp_size}") + print(f"dp size = {dp_size}") + print(f"rank= {rank}") + + all_to_all = get_all_to_all( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=( + 0 + if moe.in_dtype.itemsize != 1 + else ( + (moe.hidden_dim + moe.block_size - 1) + // moe.block_size + * torch.float32.itemsize + ) + ) + ) - if False: dispatch_combine = PplxDispatchCombine( all_to_all, - MOE_DP_CHUNK_SIZE, - moe.ep_size, - moe.dp_size, + max_num_tokens, + world_size, + dp_size, + rank, # just for debugging moe.in_dtype, ) else: @@ -1036,7 +1010,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp - #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fa717c40c774..fd1fbb167514 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -9,6 +9,8 @@ moe_kernel_quantize_input) +logger = init_logger(__name__) + # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. From 792d7518559b824b2b1bc5e84b607afec3bc490e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 10 Apr 2025 21:48:22 +0000 Subject: [PATCH 132/171] seems to be working Signed-off-by: Bill Nell --- .../layers/fused_moe/deep_gemm_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 23 +++-- vllm/model_executor/layers/fused_moe/layer.py | 85 ++++++++++--------- .../layers/fused_moe/modular_kernel.py | 6 +- .../layers/fused_moe/pplx_dispatch_combine.py | 11 ++- 5 files changed, 70 insertions(+), 59 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 266ba3bfa07a..a694c53d9f36 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -134,7 +134,9 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - self.activation(activation, workspace2, workspace1.view(-1, N)) + self.activation(activation, + workspace2, + workspace1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6cde135570ed..6a57b117ce0c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1678,12 +1678,20 @@ def apply( intermediate_cache3 = _resize_cache(workspace13, (num_tokens, top_k_num, K)) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size( - topk_ids, - config['BLOCK_SIZE_M'] if self.block_m is None else self.block_m, - global_num_experts, expert_map - )) + if hidden_states.dim() == 2: #block_m is None: + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size( + topk_ids, + config['BLOCK_SIZE_M'], + global_num_experts, expert_map + )) + else: + stride = hidden_states.shape[1] + sorted_token_ids = torch.arange(0, hidden_states.shape[0], device=hidden_states.device, dtype=torch.int) + sorted_token_ids = sorted_token_ids * stride + expert_ids = torch.logical_not(torch.isnan(hidden_states)).sum(dim=(1,2)).nonzero() + num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) invoke_fused_moe_kernel(hidden_states, w1, @@ -1706,7 +1714,8 @@ def apply( per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) - self.activation(activation, intermediate_cache2, + self.activation(activation, + intermediate_cache2, intermediate_cache1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c114dd79cde2..4904592467d7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -241,7 +241,7 @@ def apply( # Maybe extra args def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - print(f"block_m = {block_m}") + #print(f"block_m = {block_m}") experts = TritonExperts( use_fp8_w8a8 = False, @@ -550,8 +550,8 @@ def __init__( self.ep_size = 1 self.local_num_experts = self.global_num_experts self.expert_map = None + #self.global_num_experts = num_experts redundant? self.top_k = top_k - self.global_num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size @@ -571,11 +571,12 @@ def __init__( if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") + if current_platform.is_hpu(): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - print(f"params dtype= {params_dtype}") + #print(f"params dtype= {params_dtype}") moe = MoEConfig( num_experts=self.global_num_experts, @@ -604,13 +605,13 @@ def __init__( self.quant_method = quant_method # TODO: move to method? - if self.dp_size > 1: - if True: - max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size - world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. - rank = moe.ep_rank + if False and self.dp_size > 1: + max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + world_size = moe.ep_size + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + rank = moe.ep_rank + if False: print(f"max num = {max_num_tokens}") print(f"world size = {world_size}") print(f"moe ep size = {moe.ep_size}") @@ -618,45 +619,45 @@ def __init__( print(f"dp size = {dp_size}") print(f"rank= {rank}") - all_to_all = get_all_to_all( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=( - 0 - if moe.in_dtype.itemsize != 1 - else ( - (moe.hidden_dim + moe.block_size - 1) - // moe.block_size - * torch.float32.itemsize - ) + all_to_all = get_all_to_all( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=( + 0 + if moe.in_dtype.itemsize != 1 + else ( + (moe.hidden_dim + moe.block_size - 1) + // moe.block_size + * torch.float32.itemsize ) ) + ) - dispatch_combine = PplxDispatchCombine( - all_to_all, - max_num_tokens, - world_size, - dp_size, - rank, # just for debugging - moe.in_dtype, - ) - else: - dispatch_combine = StandardDispatchCombine( - moe.in_dtype, - quant_config.weight_block_size if quant_config is not None else None, - ) + dispatch_combine = PplxDispatchCombine( + all_to_all, + max_num_tokens, + world_size, + dp_size, + rank, # just for debugging + moe.in_dtype, + ) success = self.quant_method.set_dispatch_combine(dispatch_combine) if not success: logger.warning("DP+EP not supported for %s.", type(self.quant_method)) + else: + dispatch_combine = StandardDispatchCombine( + moe.in_dtype, + quant_config.weight_block_size if quant_config is not None else None, + ) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -1010,7 +1011,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp - print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f7b3f7899dd1..a8b8ba652373 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -60,9 +60,6 @@ def _moe_problem_size( E, N, _ = w1.shape K = w2.shape[1] - assert topk_ids.dim() == 2 - topk = topk_ids.shape[1] - if a1.dim() == 2: # Make sure we are using the correct a1 (pre-permute). assert topk_ids.shape[0] == a1.shape[0], \ @@ -73,6 +70,9 @@ def _moe_problem_size( assert E == a1.shape[0] M = a1.shape[1] # This is max_num_tokens + assert topk_ids.dim() == 2 + topk = topk_ids.shape[1] + return E, M, N, K, topk diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fd1fbb167514..983cc894ffec 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -32,10 +32,6 @@ def __init__(self, self.dp_size = dp_size self.rank = rank self.quant_dtype = quant_dtype - print(f"max_num_tokens = {max_num_tokens}") - print(f"dp_num_tokens = {self.dp_num_tokens}") - print(f"world_size = {world_size}") - print(f"dp_size = {dp_size}") def dispatch( self, @@ -77,7 +73,7 @@ def dispatch( dtype=torch.int32, device=a1.device, ) - expert_num_tokens.fill_(-1) + expert_num_tokens.fill_(-1) # debugging remove num_dp = self.world_size // self.dp_size expert_x = torch.empty( @@ -85,7 +81,7 @@ def dispatch( dtype=a1q.dtype, device=a1.device, ) - expert_x.fill_(torch.nan) + expert_x.fill_(torch.nan) # debugging remove expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -146,3 +142,6 @@ def combine( weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) + + #print("END COMBINE") + From be24517b298b7989ea208b0d18b452eb0c786bf8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 11 Apr 2025 20:33:42 +0000 Subject: [PATCH 133/171] wip Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 1 + .../layers/fused_moe/fused_moe.py | 16 +++++++++----- vllm/model_executor/layers/fused_moe/layer.py | 6 +++-- .../layers/fused_moe/modular_kernel.py | 5 +++++ .../layers/fused_moe/pplx_dispatch_combine.py | 22 ++++++++++++++----- 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 088dc49bf3fd..59ca9899ba0f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1098,6 +1098,7 @@ def get_tensor_model_parallel_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP + nvshmem_finalize() if _TP: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6a57b117ce0c..2b4b725065fa 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1686,12 +1686,18 @@ def apply( global_num_experts, expert_map )) else: - stride = hidden_states.shape[1] - sorted_token_ids = torch.arange(0, hidden_states.shape[0], device=hidden_states.device, dtype=torch.int) - sorted_token_ids = sorted_token_ids * stride - expert_ids = torch.logical_not(torch.isnan(hidden_states)).sum(dim=(1,2)).nonzero() - num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int) + #stride = hidden_states.shape[1] + sorted_token_ids = torch.arange(0, num_tokens*hidden_states.shape[1], device=hidden_states.device, dtype=torch.int) + sorted_token_ids = sorted_token_ids.flatten() + nans = torch.isnan(hidden_states).sum(dim=(1,2)) + expert_ids = torch.where((nans > 0).flatten(), -1, torch.arange(0, nans.numel(), device=hidden_states.device, dtype=torch.int32)) + #expert_ids = torch.repeat_interleave(expert_ids, hidden_states.shape[1], dim=0) + #print(f"EXPERT_IDS {nans.shape} {expert_ids}") + #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32) + num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32) + num_tokens_post_padded.fill_(num_tokens) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + #print(f"P = {sorted_token_ids}, {hidden_states.shape}") invoke_fused_moe_kernel(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4904592467d7..15ed0ee30e83 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -116,7 +116,7 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) - if instance is None: + if True or instance is None: instance = pplx.AllToAll(**kwargs) self._cache[key] = instance return instance @@ -605,7 +605,7 @@ def __init__( self.quant_method = quant_method # TODO: move to method? - if False and self.dp_size > 1: + if self.dp_size > 1: max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. @@ -1029,6 +1029,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] + print(f"loop {chunk_start}:{chunk_end}") + cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( max=moe_dp_chunk_size_per_rank), diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a8b8ba652373..76ece80ba474 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -312,6 +312,9 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ + from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) + print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) @@ -361,4 +364,6 @@ def forward( self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) + print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 983cc894ffec..223b5d3d2aae 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -46,6 +46,8 @@ def dispatch( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device + num_tokens = a1.shape[0] # M + hidden_dim = a1.shape[-1] # K assert expert_map is None, "NYI" @@ -71,7 +73,7 @@ def dispatch( expert_num_tokens = torch.empty( num_local_experts, dtype=torch.int32, - device=a1.device, + device=device, ) expert_num_tokens.fill_(-1) # debugging remove @@ -79,7 +81,7 @@ def dispatch( expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, - device=a1.device, + device=device, ) expert_x.fill_(torch.nan) # debugging remove @@ -95,7 +97,7 @@ def dispatch( (expert_x.size(2) + block_size - 1) // block_size, ), dtype=torch.float32, - device=a1.device, + device=device, ) # This argument is optional, defaults to indices.shape[0] @@ -105,7 +107,7 @@ def dispatch( bound_m = None # TODO: optimize this? - indices = rank_topk_ids.to(dtype=torch.uint32) + indices = rank_topk_ids.to(dtype=torch.uint32).to(device) self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -126,8 +128,17 @@ def combine( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: + device = fused_expert_output.device + #device = torch.device("cuda", self.rank) + #device = get_dp_group().device + #assert fused_expert_output.device == device + + print(f"COMBINE START {self.rank}") + # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #num_tokens = fused_expert_output.shape[0] # M + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) bound_m = None assert output.shape[0] <= self.max_num_tokens @@ -143,5 +154,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - #print("END COMBINE") - + print(f"COMBINE END {self.rank}") From 16092a5f24bb13502feb238873cc7c502fe05268 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 14 Apr 2025 21:35:58 +0000 Subject: [PATCH 134/171] batched moe test Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 138 ++++++++++++++++++++++++++++- vllm/distributed/parallel_state.py | 33 +++++-- 2 files changed, 161 insertions(+), 10 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index c58ddbb74e38..2ce0fa2b92aa 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -14,7 +14,7 @@ from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_moe, fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) @@ -25,6 +25,7 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.activation import SiluAndMul NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] @@ -106,6 +107,141 @@ def test_fused_moe( rtol=0) +def batch_by_experts( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int +) -> torch.Tensor: + #print(topk_ids.shape, topk_ids) + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device) + for i in range(topk_ids.shape[0]): + for j in range(topk_ids.shape[1]): + expert_id = topk_ids[i, j] + tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 + + #print(f"token_per_expert {tokens_per_expert.max()}") + max_num_tokens = tokens_per_expert.max() + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + dtype=a.dtype, device=a.device) + #print(f"b_a shape {b_a.shape}") + + #experts_per_token = torch.zeros(a.shape[0], dtype=torch.int, device=a.device) + + for i in range(topk_ids.shape[0]): + for j in range(topk_ids.shape[1]): + expert_id = topk_ids[i, j] + #idx = experts_per_token[i] + b_a[expert_id, j:j+1, :] = a[i, :] + #experts_per_token[i] = experts_per_token[i] + 1 + + return b_a, tokens_per_expert + + +def unbatch_output(b_out, topk_ids, K): + num_tokens, topk = topk_ids.shape + + #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}") + num_experts = b_out.shape[0] + out = torch.zeros((num_tokens, topk, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + #print(f"b_out[0] = {b_out[0].shape}") + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, i:i+1, :] = b_out[expert_id, idx:idx+1, :] + idx = idx + 1 + expert_counts[expert_id] = idx + + return out + + +def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): + assert a.dim() == 3 + #print(f"A = {a.shape} {a[0, :, :].shape}") + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = a.shape + num_experts = w1.shape[0] + out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]), dtype=a.dtype, device=a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + + out = unbatch_output(out, topk_ids, w2.shape[1]) + + return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) + + +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + e_map = None + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + + if True: + triton_output = torch_batched_moe(b_a, + w1, + w2, + tokens_per_expert, + topk_weight, + topk_ids) + else: + triton_output = fused_experts(a, # b_a + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e) + + #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + @pytest.mark.parametrize("m", [1, 32, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 1024]) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 59ca9899ba0f..ade7b5183ddf 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -915,16 +915,31 @@ def init_distributed_environment( "world group already initialized with a different world size") +PPLX_DID_INIT: bool = False + @run_once def pplx_init(rank, world_size): - print(f"PPLX_INIT {rank} {world_size}") - uid = nvshmem_get_unique_id( - ) if rank == 0 else nvshmem_alloc_empty_unique_id() - uid_gpu = uid.cuda() - get_world_group().broadcast(uid_gpu, src=0) - print(f"PPLX_INIT UID={uid_gpu}") - uid = uid_gpu.to(device='cpu') - nvshmem_init(uid, rank, world_size) + if world_size > 1: + try: + global PPLX_DID_INIT + print(f"PPLX_INIT {rank} {world_size}") + uid = nvshmem_get_unique_id( + ) if rank == 0 else nvshmem_alloc_empty_unique_id() + uid_gpu = uid.cuda() + get_world_group().broadcast(uid_gpu, src=0) + print(f"PPLX_INIT UID={uid_gpu}") + uid = uid_gpu.to(device='cpu') + nvshmem_init(uid, rank, world_size) + PPLX_DID_INIT = True + except Exception as ex: + logger.error("Failed to initialize nvshmem for pplx: %s", ex) + + +@run_once +def pplx_finalize(): + global PPLX_DID_INIT + if PPLX_DID_INIT: + nvshmem_finalize() def initialize_model_parallel( @@ -1099,7 +1114,7 @@ def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP - nvshmem_finalize() + pplx_finalize() if _TP: _TP.destroy() From 1d98c322117a9db757277df3a855809f4fc22e65 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 02:13:33 +0000 Subject: [PATCH 135/171] simple test Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 58 +++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 2ce0fa2b92aa..5d0cc91bb5b1 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -116,9 +116,12 @@ def batch_by_experts( assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] + num_tokens = a.shape[0] + topk = topk_ids.shape[1] + tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device) - for i in range(topk_ids.shape[0]): - for j in range(topk_ids.shape[1]): + for i in range(num_tokens): + for j in range(topk): expert_id = topk_ids[i, j] tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 @@ -128,34 +131,41 @@ def batch_by_experts( dtype=a.dtype, device=a.device) #print(f"b_a shape {b_a.shape}") - #experts_per_token = torch.zeros(a.shape[0], dtype=torch.int, device=a.device) + experts_per_token = torch.zeros(num_experts, dtype=torch.int, device=a.device) - for i in range(topk_ids.shape[0]): - for j in range(topk_ids.shape[1]): - expert_id = topk_ids[i, j] - #idx = experts_per_token[i] - b_a[expert_id, j:j+1, :] = a[i, :] - #experts_per_token[i] = experts_per_token[i] + 1 + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = experts_per_token[expert_id] + b_a[expert_id, idx:idx+1, :] = a[token, :] + experts_per_token[expert_id] = experts_per_token[expert_id] + 1 + + if False: + print(f"topk_ids = {topk_ids}") + print(f"tokens_per_expert = {tokens_per_expert}") + print(f"experts_per_token = {experts_per_token}") return b_a, tokens_per_expert -def unbatch_output(b_out, topk_ids, K): +def unbatch_output(b_out, topk_weight, topk_ids, K): num_tokens, topk = topk_ids.shape #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}") num_experts = b_out.shape[0] - out = torch.zeros((num_tokens, topk, K), dtype=b_out.dtype, device=b_out.device) + topk = topk_ids.shape[1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] #print(f"b_out[0] = {b_out[0].shape}") for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - out[token, i:i+1, :] = b_out[expert_id, idx:idx+1, :] - idx = idx + 1 - expert_counts[expert_id] = idx + #print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}") + out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 return out @@ -173,9 +183,9 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - out = unbatch_output(out, topk_ids, w2.shape[1]) + out = unbatch_output(out, topk_weight, topk_ids, K) - return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) + return out #(out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) def torch_moe2(a, w1, w2, topk_weight, topk_ids): @@ -200,6 +210,12 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +#@pytest.mark.parametrize("m", [33]) +#@pytest.mark.parametrize("n", [128]) +#@pytest.mark.parametrize("k", [128]) +#@pytest.mark.parametrize("e", [8]) +#@pytest.mark.parametrize("topk", [2]) +#@pytest.mark.parametrize("dtype", [torch.float16]) def test_fused_moe_batched_experts( m: int, n: int, @@ -208,12 +224,13 @@ def test_fused_moe_batched_experts( topk: int, dtype: torch.dtype, ): + current_platform.seed_everything(7) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - e_map = None vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): @@ -238,6 +255,13 @@ def test_fused_moe_batched_experts( topk_ids, global_num_experts=e) + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) From 0dfd27ed3fd9929c2f731052a998122a62ecfe12 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 02:16:23 +0000 Subject: [PATCH 136/171] cleanup Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 5d0cc91bb5b1..e4d043a7007a 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -112,7 +112,6 @@ def batch_by_experts( topk_ids: torch.Tensor, num_experts: int ) -> torch.Tensor: - #print(topk_ids.shape, topk_ids) assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] @@ -125,25 +124,19 @@ def batch_by_experts( expert_id = topk_ids[i, j] tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 - #print(f"token_per_expert {tokens_per_expert.max()}") max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, device=a.device) #print(f"b_a shape {b_a.shape}") - experts_per_token = torch.zeros(num_experts, dtype=torch.int, device=a.device) + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) for token in range(num_tokens): for j in range(topk): expert_id = topk_ids[token, j] - idx = experts_per_token[expert_id] + idx = token_counts[expert_id] b_a[expert_id, idx:idx+1, :] = a[token, :] - experts_per_token[expert_id] = experts_per_token[expert_id] + 1 - - if False: - print(f"topk_ids = {topk_ids}") - print(f"tokens_per_expert = {tokens_per_expert}") - print(f"experts_per_token = {experts_per_token}") + token_counts[expert_id] = token_counts[expert_id] + 1 return b_a, tokens_per_expert @@ -151,7 +144,6 @@ def batch_by_experts( def unbatch_output(b_out, topk_weight, topk_ids, K): num_tokens, topk = topk_ids.shape - #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}") num_experts = b_out.shape[0] topk = topk_ids.shape[1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) @@ -159,11 +151,9 @@ def unbatch_output(b_out, topk_weight, topk_ids, K): experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] - #print(f"b_out[0] = {b_out[0].shape}") for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - #print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}") out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] expert_counts[expert_id] = expert_counts[expert_id] + 1 @@ -172,7 +162,6 @@ def unbatch_output(b_out, topk_weight, topk_ids, K): def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): assert a.dim() == 3 - #print(f"A = {a.shape} {a[0, :, :].shape}") num_tokens, topk = topk_ids.shape _, max_num_tokens, K = a.shape num_experts = w1.shape[0] @@ -180,12 +169,12 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: - #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + #out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) out = unbatch_output(out, topk_weight, topk_ids, K) - return out #(out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) + return out def torch_moe2(a, w1, w2, topk_weight, topk_ids): @@ -210,12 +199,6 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -#@pytest.mark.parametrize("m", [33]) -#@pytest.mark.parametrize("n", [128]) -#@pytest.mark.parametrize("k", [128]) -#@pytest.mark.parametrize("e", [8]) -#@pytest.mark.parametrize("topk", [2]) -#@pytest.mark.parametrize("dtype", [torch.float16]) def test_fused_moe_batched_experts( m: int, n: int, @@ -248,7 +231,7 @@ def test_fused_moe_batched_experts( topk_weight, topk_ids) else: - triton_output = fused_experts(a, # b_a + triton_output = fused_experts(b_a, w1, w2, topk_weight, @@ -262,7 +245,6 @@ def test_fused_moe_batched_experts( print("OUTPUT") print(triton_output) - #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) From f6acee68b03a50560b4dd301d3b5eaebce6db63c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 15:01:31 +0000 Subject: [PATCH 137/171] test pplx w/naive implementation Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 21 +++--- .../layers/fused_moe/fused_moe.py | 66 +++++++++++++++++++ vllm/model_executor/layers/fused_moe/layer.py | 21 +++--- .../layers/fused_moe/triton_deep_gemm_moe.py | 17 +++-- 4 files changed, 99 insertions(+), 26 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index e4d043a7007a..43622e17f50d 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -118,11 +118,7 @@ def batch_by_experts( num_tokens = a.shape[0] topk = topk_ids.shape[1] - tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device) - for i in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[i, j] - tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), @@ -170,7 +166,6 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): num = tokens_per_expert[expert] if num > 0: out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - #out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) out = unbatch_output(out, topk_weight, topk_ids, K) @@ -231,12 +226,14 @@ def test_fused_moe_batched_experts( topk_weight, topk_ids) else: - triton_output = fused_experts(b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) if False: torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2b4b725065fa..49af539739d6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1754,6 +1754,72 @@ def apply( return intermediate_cache3 +class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + assert not use_fp8_w8a8 + assert not use_int4_w4a16 + assert not use_int8_w8a16 + assert block_shape is None + assert block_m is None + + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + a: torch.Tensor, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[1] + workspace13 = num_experts * max_num_tokens * K + workspace2 = M * topk * N * num_experts + return (workspace13, workspace2, a_dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + from vllm.model_executor.layers.activation import SiluAndMul + assert hidden_states.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = hidden_states.shape + num_experts = w1.shape[0] + out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + for expert in range(num_experts): + num = 1 #tokens_per_expert[expert] + if num > 0: + #out[expert, :num, :] = SiluAndMul(hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + out[expert, :, :] = SiluAndMul()(hidden_states[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + + return out + + def modular_triton_fused_moe( use_fp8_w8a8: bool, use_int8_w8a8: bool, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 15ed0ee30e83..2ae31fc3e39d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,7 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, fused_experts + from .fused_moe import TritonExperts, BatchedExperts, fused_experts from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -243,13 +243,16 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) #print(f"block_m = {block_m}") - experts = TritonExperts( - use_fp8_w8a8 = False, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = None, - block_m = None, #block_m, - ) + if False: + experts = TritonExperts( + use_fp8_w8a8 = False, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = None, + block_m = None, #block_m, + ) + else: + experts = BatchedExperts() self.fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -1029,7 +1032,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - print(f"loop {chunk_start}:{chunk_end}") + #print(f"loop {chunk_start}:{chunk_end}") cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index f3a13e44296d..21cba37478e9 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -35,16 +35,23 @@ def __init__( self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + a: torch.Tensor, + ) -> Tuple[int, int, torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): - return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) else: - return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) def apply( self, From c69354d41fb537839b209b8ab7df7040395b00f2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 15:02:05 +0000 Subject: [PATCH 138/171] test pplx w/naive implementation Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 43622e17f50d..ccdcb583d660 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -14,8 +14,9 @@ from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import fused_moe, fused_experts -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( From 4971b43bb85def6142ac1db1fbaca561a4386d78 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 17:13:17 +0000 Subject: [PATCH 139/171] hack fix for chunking loop Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 33 ++++++++++--------- .../layers/fused_moe/fused_moe.py | 10 +++--- vllm/model_executor/layers/fused_moe/layer.py | 22 +++++++++++-- .../layers/fused_moe/modular_kernel.py | 6 ++-- .../layers/fused_moe/pplx_dispatch_combine.py | 4 +-- 5 files changed, 48 insertions(+), 27 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index ccdcb583d660..3673dde94d87 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -108,7 +108,7 @@ def test_fused_moe( rtol=0) -def batch_by_experts( +def torch_dispatch( a: torch.Tensor, topk_ids: torch.Tensor, num_experts: int @@ -138,14 +138,14 @@ def batch_by_experts( return b_a, tokens_per_expert -def unbatch_output(b_out, topk_weight, topk_ids, K): +def torch_combine(b_out, topk_weight, topk_ids): num_tokens, topk = topk_ids.shape num_experts = b_out.shape[0] topk = topk_ids.shape[1] + K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) - experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): @@ -157,22 +157,25 @@ def unbatch_output(b_out, topk_weight, topk_ids, K): return out -def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): - assert a.dim() == 3 - num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = a.shape +def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): num_experts = w1.shape[0] - out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]), dtype=a.dtype, device=a.device) + b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and K == w2.shape[1] + out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: - out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) - out = unbatch_output(out, topk_weight, topk_ids, K) - - return out + return torch_combine(out, topk_weight, topk_ids) +# TODO: same as torch_moe but with fused_topk factored out. def torch_moe2(a, w1, w2, topk_weight, topk_ids): M, K = a.shape topk = topk_ids.shape[1] @@ -217,16 +220,14 @@ def test_fused_moe_batched_experts( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - if True: - triton_output = torch_batched_moe(b_a, + triton_output = torch_batched_moe(a, w1, w2, - tokens_per_expert, topk_weight, topk_ids) else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) triton_output = fused_batched_experts( b_a, w1, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 49af539739d6..aea3d5edcd51 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1783,7 +1783,7 @@ def workspace_shapes( ) -> Tuple[int, int, torch.dtype]: max_num_tokens = a.shape[1] workspace13 = num_experts * max_num_tokens * K - workspace2 = M * topk * N * num_experts + workspace2 = max_num_tokens * (N // 2) return (workspace13, workspace2, a_dtype) def apply( @@ -1810,12 +1810,14 @@ def apply( _, max_num_tokens, K = hidden_states.shape num_experts = w1.shape[0] out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + # causes deadlock #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) for expert in range(num_experts): - num = 1 #tokens_per_expert[expert] + num = max_num_tokens #tokens_per_expert[expert] if num > 0: - #out[expert, :num, :] = SiluAndMul(hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - out[expert, :, :] = SiluAndMul()(hidden_states[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) + torch.ops._C.silu_and_mul(tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2ae31fc3e39d..075510191141 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1028,11 +1028,15 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): + #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") + + #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") + + for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - #print(f"loop {chunk_start}:{chunk_end}") + #print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape}") cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( @@ -1062,6 +1066,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, activation=self.activation, ) + #print(f"final1 = {final_hidden_states.shape}") + if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ self.dp_rank - 1] @@ -1071,19 +1077,31 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] + #print(f"final2 (AR) = {final_hidden_states.shape}") + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) + #print(f"final3 (AR) = {final_hidden_states.shape}") + full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) + #print(f"full final = {full_final_hidden_states.shape}") + # Update bounds num_tokens_remaining_across_dp = torch.clamp( num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) + #print(f"num remaining = {num_tokens_remaining_across_dp}") + + # HACK FIX + if num_tokens_remaining_across_dp.sum() == 0: + break + def update_chunk_bound(x: int): return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 76ece80ba474..35f8b8292771 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -312,8 +312,8 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) - print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + #from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) + #print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) @@ -364,6 +364,6 @@ def forward( self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) - print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + #print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 223b5d3d2aae..9377d6d63317 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -133,7 +133,7 @@ def combine( #device = get_dp_group().device #assert fused_expert_output.device == device - print(f"COMBINE START {self.rank}") + #print(f"COMBINE START {self.rank}") # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -154,4 +154,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - print(f"COMBINE END {self.rank}") + #print(f"COMBINE END {self.rank}") From fedb2d2c113342ef9afb1a22976d890c136b0198 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 16 Apr 2025 20:34:49 +0000 Subject: [PATCH 140/171] wip. add pplx unit test Signed-off-by: Bill Nell --- examples/offline_inference/data_parallel.py | 2 +- tests/kernels/moe/test_moe.py | 2 - tests/kernels/test_pplx_moe.py | 432 ++++++++++++++++++ .../layers/fused_moe/fused_moe.py | 93 +++- vllm/model_executor/layers/fused_moe/layer.py | 36 +- .../layers/fused_moe/modular_kernel.py | 2 +- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 3 + 8 files changed, 550 insertions(+), 22 deletions(-) create mode 100644 tests/kernels/test_pplx_moe.py diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 965915beaf58..1c0701051890 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -160,7 +160,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=300) + proc.join(timeout=3000) if proc.exitcode is None: print(f"Killing process {proc.pid} that " f"didn't stop within 5 minutes.") diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 3673dde94d87..1807e1b22be7 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -140,9 +140,7 @@ def torch_dispatch( def torch_combine(b_out, topk_weight, topk_ids): num_tokens, topk = topk_ids.shape - num_experts = b_out.shape[0] - topk = topk_ids.shape[1] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py new file mode 100644 index 000000000000..b3b8817c69ce --- /dev/null +++ b/tests/kernels/test_pplx_moe.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_pplx_moe.py`. +""" +import dataclasses +import os +import pytest +import torch +from torch.nn import Parameter +from torch.nn import functional as F +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] +from typing import Callable, Concatenate, ParamSpec + +from pplx_kernels import AllToAll +from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, +) + +import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe import fused_moe +#from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + quantize_weights) +from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.activation import SiluAndMul + +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine +from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [2, 6] + +P = ParamSpec("P") + +require_multi_node = pytest.mark.skipif( + "MASTER_ADDR" not in os.environ, + reason="Requires multi-node environment", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception: + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_dispatch( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int +) -> torch.Tensor: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens = a.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + + max_num_tokens = tokens_per_expert.max() + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + dtype=a.dtype, device=a.device) + #print(f"b_a shape {b_a.shape}") + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx+1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_combine(b_out, topk_weight, topk_ids): + num_tokens, topk = topk_ids.shape + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return out + + +def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and K == w2.shape[1] + out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_combine(out, topk_weight, topk_ids) + + +# TODO: same as torch_moe but with fused_topk factored out. +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + if True: + triton_output = torch_batched_moe(a, + w1, + w2, + topk_weight, + topk_ids) + else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) + + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + + +def torch_pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): + hidden_dim = a.shape[-1] + num_experts = w1.shape[0] + num_local_experts = num_experts // pgi.world_size + block_size = 128 + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + max_num_tokens = tokens_per_expert.max() + rank = pgi.rank + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + experts = BatchedExperts() + + fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + out = fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids + ) + + ata.destroy() + + return out + + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + triton_output = torch_pplx_moe(pgi, + a, + w1, + w2, + topk_weight, + topk_ids) + + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_pplx_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + world_size = 4 + dp_size = 2 + parallel_launch( + world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype + ) + diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index aea3d5edcd51..f3105c6da63e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1754,6 +1754,82 @@ def apply( return intermediate_cache3 +class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, + world_size: int, + rank: int): + super().__init__() + self.world_size = world_size + self.rank = rank + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a1.shape[0] + + num_tokens = a1.shape[0] + topk = topk_ids.shape[1] + + #assert num_experts % self.world_size == 0 + #num_local_experts = num_experts // self.world_size + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + max_num_tokens = tokens_per_expert.max() + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) + + b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), + dtype=a1.dtype, device=a1.device) + + #print(f"START DISPATCH {hex(id(self))}") + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = expert_counts[expert_id] + b_a1[expert_id, idx:idx+1, :] = a1[token, :] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + #print(f"END DISPATCH {hex(id(self))}: tokens_per_expert {(tokens_per_expert > 0).nonzero().view(-1)}") + + return b_a1, a1_scale, tokens_per_expert + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> None: + if False: + print(f"topk_ids {topk_ids.shape}") + print(f"fused_expert_output {fused_expert_output.shape}") + print(f"output {output.shape}") + print(f"counts {self.expert_counts.shape}") + + #print(f"START COMBINE {hex(id(self))}") + + num_tokens, topk = topk_ids.shape + num_experts, _, K = fused_expert_output.shape + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(topk_ids.shape[1]): + expert_id = expert_ids[i] + if expert_id < num_experts: + idx = expert_counts[expert_id] + output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + #print(f"END COMBINE {hex(id(self))}") + + class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -1803,21 +1879,28 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: - from vllm.model_executor.layers.activation import SiluAndMul + #print("START EXPERTS") assert hidden_states.dim() == 3 + assert expert_num_tokens is not None num_tokens, topk = topk_ids.shape _, max_num_tokens, K = hidden_states.shape num_experts = w1.shape[0] out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) - # causes deadlock - #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) for expert in range(num_experts): - num = max_num_tokens #tokens_per_expert[expert] + num = expert_num_tokens[expert] if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) - torch.ops._C.silu_and_mul(tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) + self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + # fill remainder with 0??? + #out[expert, num:, :].fill_(0) + else: + #out[expert, :, :].fill_(0) # ?? + pass + + #print("END EXPERTS") return out diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 075510191141..09b0f5a7e114 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,7 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, BatchedExperts, fused_experts + from .fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -116,7 +116,7 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) - if True or instance is None: + if instance is None: instance = pplx.AllToAll(**kwargs) self._cache[key] = instance return instance @@ -240,10 +240,15 @@ def apply( # Maybe extra args def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + assert self.fused_experts == fused_experts + block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) #print(f"block_m = {block_m}") - if False: + if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): + logger.info("BatchedExperts") + experts = BatchedExperts() + else: experts = TritonExperts( use_fp8_w8a8 = False, use_int8_w8a16 = False, @@ -251,8 +256,6 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine block_shape = None, block_m = None, #block_m, ) - else: - experts = BatchedExperts() self.fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -609,6 +612,7 @@ def __init__( # TODO: move to method? if self.dp_size > 1: + logger.info("using pplx dispatch") max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. @@ -652,15 +656,22 @@ def __init__( rank, # just for debugging moe.in_dtype, ) - - success = self.quant_method.set_dispatch_combine(dispatch_combine) - if not success: - logger.warning("DP+EP not supported for %s.", type(self.quant_method)) - else: + elif False: + logger.info("using standard dispatch") dispatch_combine = StandardDispatchCombine( moe.in_dtype, quant_config.weight_block_size if quant_config is not None else None, ) + else: + logger.info("using batched dispatch") + dispatch_combine = BatchedDispatchCombine( + moe.ep_size, + moe.ep_rank, + ) + + success = self.quant_method.set_dispatch_combine(dispatch_combine) + if not success: + logger.warning("DP+EP not supported for %s.", type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -1029,7 +1040,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_final_hidden_states = torch.empty_like(full_hidden_states) #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") - #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1089,7 +1099,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - #print(f"full final = {full_final_hidden_states.shape}") + #print(f"partial final = {full_final_hidden_states.shape}") # Update bounds num_tokens_remaining_across_dp = torch.clamp( @@ -1109,6 +1119,8 @@ def update_chunk_bound(x: int): chunk_start = update_chunk_bound(chunk_start) chunk_end = update_chunk_bound(chunk_end) + #print(f"full final shape {full_final_hidden_states.shape}") + return full_final_hidden_states def forward_impl(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 35f8b8292771..96ecf5990a66 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -323,7 +323,7 @@ def forward( if global_num_experts == -1: global_num_experts = E - output = a1 if inplace else torch.empty_like(a1) + output = a1 if inplace else torch.zeros_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( self.fused_experts.workspace_shapes(a1, M, N, K, top_k, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 9377d6d63317..a36c825d9e75 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -75,7 +75,7 @@ def dispatch( dtype=torch.int32, device=device, ) - expert_num_tokens.fill_(-1) # debugging remove + #expert_num_tokens.fill_(-1) # debugging remove num_dp = self.world_size // self.dp_size expert_x = torch.empty( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 21cba37478e9..be28d620f47d 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -70,6 +70,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: N = w1.shape[1] if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 @@ -90,6 +91,7 @@ def apply( a2_scale, workspace13, workspace2, + expert_num_tokens, ) else: return self.triton_expert( @@ -108,4 +110,5 @@ def apply( a2_scale, workspace13, workspace2, + expert_num_tokens, ) From 46d09b7379b04733a340fda6e0169a1dd0bb7321 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 17 Apr 2025 00:10:05 +0000 Subject: [PATCH 141/171] work on unit test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 323 +++++++++++++++--- .../layers/fused_moe/fused_moe.py | 3 +- .../layers/fused_moe/pplx_dispatch_combine.py | 17 +- 3 files changed, 286 insertions(+), 57 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index b3b8817c69ce..0156253d680e 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -7,6 +7,8 @@ import os import pytest import torch +import traceback + from torch.nn import Parameter from torch.nn import functional as F from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] @@ -38,6 +40,8 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils import round_up + from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts @@ -102,7 +106,9 @@ def _worker_parallel_launch( *args, **kwargs, ) - except Exception: + except Exception as ex: + print(ex) + traceback.print_exception(ex) raise finally: torch.distributed.destroy_process_group() @@ -247,13 +253,150 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_batched_experts( +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128, 1024, 2048]) +# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# def test_fused_moe_batched_experts( +# m: int, +# n: int, +# k: int, +# e: int, +# topk: int, +# dtype: torch.dtype, +# ): +# current_platform.seed_everything(7) + +# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 +# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 +# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + +# score = torch.randn((m, e), device="cuda", dtype=dtype) + +# vllm_config = VllmConfig() +# with set_current_vllm_config(vllm_config): +# topk_weight, topk_ids = fused_topk(a, score, topk, False) + +# torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + +# if True: +# triton_output = torch_batched_moe(a, +# w1, +# w2, +# topk_weight, +# topk_ids) +# else: +# b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) +# triton_output = fused_batched_experts( +# b_a, +# w1, +# w2, +# topk_weight, +# topk_ids, +# global_num_experts=e +# ) + +# if False: +# torch.set_printoptions(profile="full") +# print("BASELINE") +# print(torch_output) +# print("OUTPUT") +# print(triton_output) + +# torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + +def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + + max_num_tokens = num_tokens + print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") + rank = pgi.rank + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + def chunk_by_rank(t, r): + num = t.shape[0] + assert num % pgi.world_size == 0, f"{num}, {pgi.world_size}" # for now + chunk = num // pgi.world_size + print(f"chunk {t.shape}, {pgi.world_size}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") + return t[(r * chunk):(r + 1)*chunk] + + a_chunk = chunk_by_rank(a, rank).to(device) + score_chunk = chunk_by_rank(scores, rank).to(device) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + + #print(f"chunk_topk_ids = {chunk_topk_ids}") + + b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( + a_chunk, + None, + None, + chunk_topk_ids, + num_experts, # store at PplxDispatchCombine creation? + None + ) + torch.cuda.synchronize() # necessary? + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + dispatch_combine.combine( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + ) + torch.cuda.synchronize() + + ata.destroy() + + torch.distributed.barrier() + + return out[:num_tokens] + + +def _pplx_dispatch_combine( + pgi: ProcessGroupInfo, + dp_size: int, m: int, n: int, k: int, @@ -261,7 +404,9 @@ def test_fused_moe_batched_experts( topk: int, dtype: torch.dtype, ): - current_platform.seed_everything(7) + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -269,49 +414,74 @@ def test_fused_moe_batched_experts( score = torch.randn((m, e), device="cuda", dtype=dtype) - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) + topk_weight, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + print(f"a {a.shape}") + a_rep = torch.repeat_interleave(a, topk, dim=1) + print(f"a_rep {a_rep.shape}") + + torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) - if True: - triton_output = torch_batched_moe(a, + pplx_output = torch_pplx_dispatch_combine(pgi, + dp_size, + a, w1, w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) + score, + topk) if False: torch.set_printoptions(profile="full") print("BASELINE") print(torch_output) print("OUTPUT") - print(triton_output) + print(pplx_output) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128, 1024, 2048]) +# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +@pytest.mark.parametrize("topk", [2]) #TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_pplx_dispatch_combine( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + if False: + world_size = 4 + dp_size = 2 + else: + world_size = 2 + dp_size = 1 + parallel_launch( + world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype + ) -def torch_pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): +def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): hidden_dim = a.shape[-1] num_experts = w1.shape[0] num_local_experts = num_experts // pgi.world_size block_size = 128 - topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() + max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max() + print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}") rank = pgi.rank ata = AllToAll( @@ -350,20 +520,60 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): experts, ) - out = fused_experts( - a, - w1, - w2, - topk_weight, - topk_ids - ) + def chunk_by_rank(t, r): + num = t.shape[0] + assert num % pgi.world_size == 0, f"{num}, {dp_size}" # for now + chunk = num // pgi.world_size + return t[(r * chunk):(r + 1)*chunk] + + a_chunk = chunk_by_rank(a, rank) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, chunk_by_rank(scores, rank), topk, False) + + print(f"chunk_topk_ids = {chunk_topk_ids}") + + # TODO: chunk up by rank + if False: + out = fused_experts( + a_chunk, + w1, # chunk? + w2, # chunk? + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_local_experts + ) + # reduce outputs? + else: + b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( + a_chunk, + None, + None, + chunk_topk_ids, + num_experts, + None + ) + torch.cuda.synchronize() + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=a.device, + ) + + dispatch_combine.combine( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + ) + + torch.cuda.synchronize() ata.destroy() return out - def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -391,11 +601,12 @@ def _pplx_moe( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) triton_output = torch_pplx_moe(pgi, + dp_size, a, w1, w2, - topk_weight, - topk_ids) + score, + topk) if False: torch.set_printoptions(profile="full") @@ -409,12 +620,18 @@ def _pplx_moe( nvshmem_finalize() -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128, 1024, 2048]) +# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +@pytest.mark.parametrize("topk", [2]) #TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_pplx_moe( m: int, n: int, @@ -424,8 +641,12 @@ def test_pplx_moe( dtype: torch.dtype, ): current_platform.seed_everything(7) - world_size = 4 - dp_size = 2 + if False: + world_size = 4 + dp_size = 2 + else: + world_size = 2 + dp_size = 1 parallel_launch( world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f3105c6da63e..a22496b7d026 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1886,7 +1886,8 @@ def apply( assert expert_num_tokens is not None num_tokens, topk = topk_ids.shape _, max_num_tokens, K = hidden_states.shape - num_experts = w1.shape[0] + print(f"global_num_experts = {global_num_experts}") + num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) for expert in range(num_experts): num = expert_num_tokens[expert] diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index a36c825d9e75..dd8fe4a36fba 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -75,15 +75,18 @@ def dispatch( dtype=torch.int32, device=device, ) - #expert_num_tokens.fill_(-1) # debugging remove + #expert_num_tokens.fill_(-1) # debugging, remove later num_dp = self.world_size // self.dp_size + print(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, device=device, ) - expert_x.fill_(torch.nan) # debugging remove + expert_x.fill_(torch.nan) # debugging, remove later + + print(f"GOT HERE B {self.rank}") expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -100,6 +103,8 @@ def dispatch( device=device, ) + print(f"GOT HERE C {self.rank}") + # This argument is optional, defaults to indices.shape[0] # This causes a deadlock???? #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -107,7 +112,9 @@ def dispatch( bound_m = None # TODO: optimize this? - indices = rank_topk_ids.to(dtype=torch.uint32).to(device) + indices = rank_topk_ids.to(dtype=torch.uint32) + + print(f"GOT HERE D {self.rank}") self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -133,7 +140,7 @@ def combine( #device = get_dp_group().device #assert fused_expert_output.device == device - #print(f"COMBINE START {self.rank}") + print(f"COMBINE START {self.rank}") # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -154,4 +161,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - #print(f"COMBINE END {self.rank}") + print(f"COMBINE END {self.rank}") From 7db006153b6841f30129be62ab9d9fc63337eaf7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 17 Apr 2025 03:45:09 +0000 Subject: [PATCH 142/171] dispatch/combine unit test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 104 ++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index 0156253d680e..afb0b8858661 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -308,6 +308,14 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): # torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +def chunk_by_rank(t, r, w): + num = t.shape[0] + assert num % w == 0, f"{num}, {w}" # for now + chunk = num // w + #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") + return t[(r * chunk):(r + 1)*chunk] + + def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): assert torch.cuda.current_device() == pgi.local_rank @@ -315,10 +323,12 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_experts = w1.shape[0] block_size = 128 device = pgi.device + rank_num_tokens = num_tokens // pgi.world_size max_num_tokens = num_tokens - print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank + world_size = pgi.world_size ata = AllToAll( max_num_tokens=max_num_tokens, @@ -342,22 +352,15 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, + max_num_tokens, # // world_size? pgi.world_size, dp_size, rank, a.dtype, ) - def chunk_by_rank(t, r): - num = t.shape[0] - assert num % pgi.world_size == 0, f"{num}, {pgi.world_size}" # for now - chunk = num // pgi.world_size - print(f"chunk {t.shape}, {pgi.world_size}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") - return t[(r * chunk):(r + 1)*chunk] - - a_chunk = chunk_by_rank(a, rank).to(device) - score_chunk = chunk_by_rank(scores, rank).to(device) + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) #print(f"chunk_topk_ids = {chunk_topk_ids}") @@ -391,16 +394,22 @@ def chunk_by_rank(t, r): torch.distributed.barrier() - return out[:num_tokens] + #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") + + #torch.distributed.all_reduce(out) + + #print(f"AR OUT {rank}: {out.shape} {out}") + + return out[:rank_num_tokens] def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - m: int, - n: int, - k: int, - e: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, topk: int, dtype: torch.dtype, ): @@ -408,19 +417,18 @@ def _pplx_dispatch_combine( torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - score = torch.randn((m, e), device="cuda", dtype=dtype) + m, k = a.shape + e, _, n = w2.shape topk_weight, topk_ids = fused_topk(a, score, topk, False) - print(f"a {a.shape}") - a_rep = torch.repeat_interleave(a, topk, dim=1) - print(f"a_rep {a_rep.shape}") + #print(f"a {a.shape}") + a_rep = torch.repeat_interleave(a, topk, dim=0) + #print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}") + + torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).to(a.dtype).sum(dim=1) - torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, @@ -437,23 +445,25 @@ def _pplx_dispatch_combine( print("OUTPUT") print(pplx_output) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() -# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 511, 1024]) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128]) -@pytest.mark.parametrize("k", [128]) -@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -@pytest.mark.parametrize("topk", [2]) #TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) # what is restriction on this? +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions here? +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("m", [2]) ##, 32]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128]) +# @pytest.mark.parametrize("k", [128]) +# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +# @pytest.mark.parametrize("topk", [2]) #TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_pplx_dispatch_combine( m: int, n: int, @@ -469,8 +479,14 @@ def test_pplx_dispatch_combine( else: world_size = 2 dp_size = 1 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + parallel_launch( - world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype + world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype ) @@ -483,6 +499,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max() print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}") rank = pgi.rank + world_size = pgi.world_size ata = AllToAll( max_num_tokens=max_num_tokens, @@ -520,14 +537,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): experts, ) - def chunk_by_rank(t, r): - num = t.shape[0] - assert num % pgi.world_size == 0, f"{num}, {dp_size}" # for now - chunk = num // pgi.world_size - return t[(r * chunk):(r + 1)*chunk] - - a_chunk = chunk_by_rank(a, rank) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, chunk_by_rank(scores, rank), topk, False) + a_chunk = chunk_by_rank(a, rank, world_size) + score_chunk = chunk_by_rank(scores, rank, world_size) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) print(f"chunk_topk_ids = {chunk_topk_ids}") From cb7320d8f74d3a88acef33221ed5d2d7b59912d5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 17 Apr 2025 13:08:04 +0000 Subject: [PATCH 143/171] forgot file Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 104 +++++++++++++-------------------- 1 file changed, 41 insertions(+), 63 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index afb0b8858661..87c6d42862b6 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -373,7 +373,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_experts, # store at PplxDispatchCombine creation? None ) - torch.cuda.synchronize() # necessary? + #torch.cuda.synchronize() # necessary? out = torch.full( (max_num_tokens, hidden_dim), @@ -452,18 +452,12 @@ def _pplx_dispatch_combine( nvshmem_finalize() -@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) # what is restriction on this? +@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions here? +@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("m", [2]) ##, 32]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128]) -# @pytest.mark.parametrize("k", [128]) -# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -# @pytest.mark.parametrize("topk", [2]) #TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_pplx_dispatch_combine( m: int, n: int, @@ -491,13 +485,16 @@ def test_pplx_dispatch_combine( def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): - hidden_dim = a.shape[-1] + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] - num_local_experts = num_experts // pgi.world_size block_size = 128 + device = pgi.device + rank_num_tokens = num_tokens // pgi.world_size - max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max() - print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}") + max_num_tokens = num_tokens + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank world_size = pgi.world_size @@ -523,7 +520,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, + max_num_tokens, # // world_size? pgi.world_size, dp_size, rank, @@ -537,53 +534,34 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): experts, ) - a_chunk = chunk_by_rank(a, rank, world_size) - score_chunk = chunk_by_rank(scores, rank, world_size) + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - print(f"chunk_topk_ids = {chunk_topk_ids}") + #print(f"chunk_topk_ids = {chunk_topk_ids}") - # TODO: chunk up by rank - if False: - out = fused_experts( - a_chunk, - w1, # chunk? - w2, # chunk? - chunk_topk_weight, - chunk_topk_ids, - global_num_experts=num_local_experts - ) - # reduce outputs? - else: - b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( - a_chunk, - None, - None, - chunk_topk_ids, - num_experts, - None - ) - torch.cuda.synchronize() + out = fused_experts( + a_chunk, + w1, # chunk? + w2, # chunk? + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts #? num_local_experts? + ) - out = torch.full( - (max_num_tokens, hidden_dim), - torch.nan, - dtype=a.dtype, - device=a.device, - ) + torch.cuda.synchronize() - dispatch_combine.combine( - out, - b_a, - chunk_topk_weight, - chunk_topk_ids, - ) + ata.destroy() - torch.cuda.synchronize() + torch.distributed.barrier() - ata.destroy() + #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") - return out + #torch.distributed.all_reduce(out) + + print(f"OUT {rank}: {out.shape} {out}") + + return out[:rank_num_tokens] def _pplx_moe( @@ -612,29 +590,29 @@ def _pplx_moe( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - triton_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) + pplxd_output = torch_pplx_moe(pgi, + dp_size, + a, + w1, + w2, + score, + topk) if False: torch.set_printoptions(profile="full") print("BASELINE") print(torch_output) print("OUTPUT") - print(triton_output) + print(pplx_output) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() # @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) # @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("k", [128, 512, 1024]) # @pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("topk", TOP_KS) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) From fe1974a393b4e31a8fb5cef677c2645ccc644036 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 18 Apr 2025 02:22:24 +0000 Subject: [PATCH 144/171] somewhat working unit test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 137 +++++++++--------- .../layers/fused_moe/fused_moe.py | 5 +- .../layers/fused_moe/modular_kernel.py | 2 +- .../layers/fused_moe/pplx_dispatch_combine.py | 12 +- 4 files changed, 78 insertions(+), 78 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index 87c6d42862b6..f6443187f140 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -9,10 +9,8 @@ import torch import traceback -from torch.nn import Parameter -from torch.nn import functional as F from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, ParamSpec +from typing import Callable, Concatenate, ParamSpec, Tuple from pplx_kernels import AllToAll from pplx_kernels.nvshmem import ( @@ -25,27 +23,18 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, torch_moe, torch_moe_single) -from vllm import _custom_ops as ops +#from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import fused_moe +#from vllm.model_executor.layers.fused_moe import fused_moe #from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) -from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types -from vllm.utils import round_up from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine NUM_EXPERTS = [8, 64] @@ -373,7 +362,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_experts, # store at PplxDispatchCombine creation? None ) - #torch.cuda.synchronize() # necessary? + + b_a = b_a * 1.5 out = torch.full( (max_num_tokens, hidden_dim), @@ -392,7 +382,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - torch.distributed.barrier() + #torch.distributed.barrier() #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") @@ -406,19 +396,26 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, + m, n, k, e, + #a: torch.Tensor, + #w1: torch.Tensor, + #w2: torch.Tensor, + #score: torch.Tensor, topk: int, dtype: torch.dtype, ): uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device - m, k = a.shape - e, _, n = w2.shape + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + #m, k = a.shape + #e, _, n = w2.shape topk_weight, topk_ids = fused_topk(a, score, topk, False) @@ -426,7 +423,7 @@ def _pplx_dispatch_combine( a_rep = torch.repeat_interleave(a, topk, dim=0) #print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}") - torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).to(a.dtype).sum(dim=1) + torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") @@ -452,12 +449,13 @@ def _pplx_dispatch_combine( nvshmem_finalize() -@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [4, 32, 64, 222]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) def test_pplx_dispatch_combine( m: int, n: int, @@ -465,22 +463,14 @@ def test_pplx_dispatch_combine( e: int, topk: int, dtype: torch.dtype, + world_dp_size: Tuple[int, int], ): current_platform.seed_everything(7) - if False: - world_size = 4 - dp_size = 2 - else: - world_size = 2 - dp_size = 1 - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + world_size, dp_size = world_dp_size parallel_launch( - world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype + #world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype + world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype ) @@ -489,9 +479,10 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] + num_local_experts = num_experts // pgi.world_size block_size = 128 device = pgi.device - rank_num_tokens = num_tokens // pgi.world_size + rank_num_tokens = num_tokens // pgi.world_size # TODO even divide max_num_tokens = num_tokens #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") @@ -518,6 +509,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ), ) + w1 = w1.to(device) + w2 = w2.to(device) + dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, # // world_size? @@ -538,28 +532,28 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - #print(f"chunk_topk_ids = {chunk_topk_ids}") + #print(f"chunk_topk_ids {rank} {chunk_topk_ids.shape} {chunk_topk_ids.view(-1)}") out = fused_experts( a_chunk, - w1, # chunk? - w2, # chunk? + w1, + w2, chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? + global_num_experts=num_local_experts #? num_local_experts? ) torch.cuda.synchronize() ata.destroy() - torch.distributed.barrier() + #torch.distributed.barrier() #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") #torch.distributed.all_reduce(out) - print(f"OUT {rank}: {out.shape} {out}") + #print(f"OUT {rank}: {out.shape} {out}") return out[:rank_num_tokens] @@ -567,10 +561,10 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, - m: int, - n: int, - k: int, - e: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, topk: int, dtype: torch.dtype, ): @@ -578,33 +572,37 @@ def _pplx_moe( torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + m, k = a.shape + e, _, n = w2.shape - score = torch.randn((m, e), device="cuda", dtype=dtype) + torch.set_printoptions(profile="full") vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) + #print(f"topk_ids {pgi.rank} {topk_ids.shape} {topk_ids.view(-1)}") + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplxd_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) + pplx_output = torch_pplx_moe(pgi, + dp_size, + a, + w1, + w2, + score, + topk) + + #print(f"torch_output {pgi.rank}: {torch_output}") if False: - torch.set_printoptions(profile="full") print("BASELINE") print(torch_output) print("OUTPUT") print(pplx_output) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() @@ -616,12 +614,13 @@ def _pplx_moe( # @pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("topk", TOP_KS) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128]) @pytest.mark.parametrize("k", [128]) @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) @pytest.mark.parametrize("topk", [2]) #TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, n: int, @@ -629,15 +628,17 @@ def test_pplx_moe( e: int, topk: int, dtype: torch.dtype, + world_dp_size: Tuple[int, int], ): current_platform.seed_everything(7) - if False: - world_size = 4 - dp_size = 2 - else: - world_size = 2 - dp_size = 1 + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + parallel_launch( - world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype + world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype + #world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a22496b7d026..cb5da93d5429 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1858,7 +1858,7 @@ def workspace_shapes( a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: max_num_tokens = a.shape[1] - workspace13 = num_experts * max_num_tokens * K + workspace13 = num_experts * max_num_tokens * K * 2 # *2 = HACK!!!!! workspace2 = max_num_tokens * (N // 2) return (workspace13, workspace2, a_dtype) @@ -1889,7 +1889,8 @@ def apply( print(f"global_num_experts = {global_num_experts}") num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) - for expert in range(num_experts): + num_local_experts = expert_num_tokens.numel() + for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 96ecf5990a66..35f8b8292771 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -323,7 +323,7 @@ def forward( if global_num_experts == -1: global_num_experts = E - output = a1 if inplace else torch.zeros_like(a1) + output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( self.fused_experts.workspace_shapes(a1, M, N, K, top_k, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index dd8fe4a36fba..682935e2c68b 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -78,7 +78,7 @@ def dispatch( #expert_num_tokens.fill_(-1) # debugging, remove later num_dp = self.world_size // self.dp_size - print(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") + logger.debug(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, @@ -86,7 +86,7 @@ def dispatch( ) expert_x.fill_(torch.nan) # debugging, remove later - print(f"GOT HERE B {self.rank}") + logger.debug(f"GOT HERE B {self.rank}") expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -103,7 +103,7 @@ def dispatch( device=device, ) - print(f"GOT HERE C {self.rank}") + logger.debug(f"GOT HERE C {self.rank}") # This argument is optional, defaults to indices.shape[0] # This causes a deadlock???? @@ -114,8 +114,6 @@ def dispatch( # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) - print(f"GOT HERE D {self.rank}") - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -140,7 +138,7 @@ def combine( #device = get_dp_group().device #assert fused_expert_output.device == device - print(f"COMBINE START {self.rank}") + logger.debug(f"COMBINE START {self.rank}") # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -161,4 +159,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - print(f"COMBINE END {self.rank}") + logger.debug(f"COMBINE END {self.rank}") From 86c2055f1429db7845097546f35a9ad6858ef675 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 18 Apr 2025 19:31:31 +0000 Subject: [PATCH 145/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 164 ++++++++++-------- .../layers/fused_moe/fused_moe.py | 4 +- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- 3 files changed, 93 insertions(+), 77 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index f6443187f140..b80ebfd64a09 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -164,7 +164,7 @@ def torch_dispatch( a: torch.Tensor, topk_ids: torch.Tensor, num_experts: int -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] @@ -172,10 +172,11 @@ def torch_dispatch( topk = topk_ids.shape[1] tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, device=a.device) + #print(f"b_a shape {b_a.shape}") token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) @@ -242,59 +243,58 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 511, 1024]) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# def test_fused_moe_batched_experts( -# m: int, -# n: int, -# k: int, -# e: int, -# topk: int, -# dtype: torch.dtype, -# ): -# current_platform.seed_everything(7) - -# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 -# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 -# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - -# score = torch.randn((m, e), device="cuda", dtype=dtype) - -# vllm_config = VllmConfig() -# with set_current_vllm_config(vllm_config): -# topk_weight, topk_ids = fused_topk(a, score, topk, False) - -# torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - -# if True: -# triton_output = torch_batched_moe(a, -# w1, -# w2, -# topk_weight, -# topk_ids) -# else: -# b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) -# triton_output = fused_batched_experts( -# b_a, -# w1, -# w2, -# topk_weight, -# topk_ids, -# global_num_experts=e -# ) - -# if False: -# torch.set_printoptions(profile="full") -# print("BASELINE") -# print(torch_output) -# print("OUTPUT") -# print(triton_output) - -# torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + if True: + triton_output = torch_batched_moe(a, + w1, + w2, + topk_weight, + topk_ids) + else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) + + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) def chunk_by_rank(t, r, w): @@ -310,6 +310,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] + num_local_experts = w1.shape[0] // pgi.world_size block_size = 128 device = pgi.device rank_num_tokens = num_tokens // pgi.world_size @@ -352,7 +353,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - #print(f"chunk_topk_ids = {chunk_topk_ids}") + #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -363,6 +364,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): None ) + #topk_weight, topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) + + torch.distributed.all_reduce(tokens_per_expert) + #max_num = tokens_per_expert.max() + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) + + #print(f"tpe {tokens_per_expert}") + #print(f"ent {expert_num_tokens}") + + #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size) + + #torch.set_printoptions(profile="full") + #print("b_a", b_a[:naive_b_a.shape[1]]) + #print("naive_b_a", naive_b_a) + + torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) + #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0) + b_a = b_a * 1.5 out = torch.full( @@ -382,8 +402,6 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - #torch.distributed.barrier() - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") #torch.distributed.all_reduce(out) @@ -547,8 +565,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - #torch.distributed.barrier() - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") #torch.distributed.all_reduce(out) @@ -593,8 +609,6 @@ def _pplx_moe( score, topk) - #print(f"torch_output {pgi.rank}: {torch_output}") - if False: print("BASELINE") print(torch_output) @@ -603,23 +617,25 @@ def _pplx_moe( torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() -# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 512, 1024]) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128]) -@pytest.mark.parametrize("k", [128]) -@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -@pytest.mark.parametrize("topk", [2]) #TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128]) +# @pytest.mark.parametrize("k", [128]) +# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +# @pytest.mark.parametrize("topk", [2]) #TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index cb5da93d5429..bcb8212f2dcb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1858,8 +1858,8 @@ def workspace_shapes( a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: max_num_tokens = a.shape[1] - workspace13 = num_experts * max_num_tokens * K * 2 # *2 = HACK!!!!! - workspace2 = max_num_tokens * (N // 2) + workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack + workspace2 = max_num_tokens * N return (workspace13, workspace2, a_dtype) def apply( diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 682935e2c68b..10c02fb2ff24 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -84,7 +84,7 @@ def dispatch( dtype=a1q.dtype, device=device, ) - expert_x.fill_(torch.nan) # debugging, remove later + expert_x.fill_(0) #torch.nan # debugging, remove later logger.debug(f"GOT HERE B {self.rank}") From 58fe406c51deb4cb22e30457718ed4fc659ecc6d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 18 Apr 2025 22:37:31 +0000 Subject: [PATCH 146/171] fix test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 46 +++++++++---------- .../layers/fused_moe/fused_moe.py | 18 +++++++- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index b80ebfd64a09..a62dbbcc4cd7 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -10,7 +10,7 @@ import traceback from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, ParamSpec, Tuple +from typing import Callable, Concatenate, Optional, ParamSpec, Tuple from pplx_kernels import AllToAll from pplx_kernels.nvshmem import ( @@ -163,7 +163,8 @@ def parallel_launch_from_env( def torch_dispatch( a: torch.Tensor, topk_ids: torch.Tensor, - num_experts: int + num_experts: int, + max_num_tokens: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] @@ -172,7 +173,8 @@ def torch_dispatch( topk = topk_ids.shape[1] tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() + if max_num_tokens is None: + max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, device=a.device) @@ -314,11 +316,10 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): block_size = 128 device = pgi.device rank_num_tokens = num_tokens // pgi.world_size - - max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank world_size = pgi.world_size + max_num_tokens = num_tokens + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") ata = AllToAll( max_num_tokens=max_num_tokens, @@ -342,7 +343,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, # // world_size? + max_num_tokens, pgi.world_size, dp_size, rank, @@ -353,7 +354,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") + print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -371,14 +372,17 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): #max_num = tokens_per_expert.max() tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) - #print(f"tpe {tokens_per_expert}") - #print(f"ent {expert_num_tokens}") + print(f"tpe {tokens_per_expert}") + print(f"ent {expert_num_tokens}") + + #torch.set_printoptions(profile="full") + #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX) + #torch.distributed.broadcast(naive_b_a, src=rank) #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size) - #torch.set_printoptions(profile="full") - #print("b_a", b_a[:naive_b_a.shape[1]]) - #print("naive_b_a", naive_b_a) + #print("b_a", b_a.shape, b_a) #[:, :naive_b_a.shape[1]]) + #print("naive_b_a", naive_b_a.shape, naive_b_a) torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0) @@ -386,7 +390,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): b_a = b_a * 1.5 out = torch.full( - (max_num_tokens, hidden_dim), + (rank_num_tokens * world_size, hidden_dim), torch.nan, dtype=a.dtype, device=device, @@ -539,7 +543,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): a.dtype, ) - experts = BatchedExperts() + experts = BatchedExperts(max_num_tokens, rank) fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -554,24 +558,20 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): out = fused_experts( a_chunk, - w1, - w2, + chunk_by_rank(w1, rank, world_size), + chunk_by_rank(w2, rank, world_size), chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_local_experts #? num_local_experts? + global_num_experts=num_experts #? num_local_experts? ) torch.cuda.synchronize() ata.destroy() - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") - - #torch.distributed.all_reduce(out) - #print(f"OUT {rank}: {out.shape} {out}") - return out[:rank_num_tokens] + return out[:rank_num_tokens] # chunk_by_rank? def _pplx_moe( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bcb8212f2dcb..75cc3f29ca6a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1834,6 +1834,8 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + max_num_tokens: Optional[int] = None, + rank: int = 0, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1846,6 +1848,8 @@ def __init__( assert not use_int8_w8a16 assert block_shape is None assert block_m is None + self.max_num_tokens = max_num_tokens + self.rank = rank def workspace_shapes( self, @@ -1857,7 +1861,8 @@ def workspace_shapes( num_experts: int, a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: - max_num_tokens = a.shape[1] + #assert self.max_num_tokens >= a.shape[1] + max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack workspace2 = max_num_tokens * N return (workspace13, workspace2, a_dtype) @@ -1885,13 +1890,20 @@ def apply( assert hidden_states.dim() == 3 assert expert_num_tokens is not None num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = hidden_states.shape + _, tmp_max_num_tokens, K = hidden_states.shape + max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens print(f"global_num_experts = {global_num_experts}") num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() + #assert num_local_experts >= topk_ids.view(-1).max() + #print(f"apply a={hidden_states}") + #print(f"apply topk={topk_ids}") + #print(f"apply num_tokens={expert_num_tokens}") + for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] + assert num <= max_num_tokens if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @@ -1904,6 +1916,8 @@ def apply( #print("END EXPERTS") + #print(f"apply out={out}") + return out From 4fb31ef0c777684601aad17e7f097e8d0a11ca1a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 19 Apr 2025 01:08:54 +0000 Subject: [PATCH 147/171] some cleanup Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 41 ++++++++----------- .../layers/fused_moe/fused_moe.py | 22 ++-------- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- 3 files changed, 23 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index a62dbbcc4cd7..0e5e0cd77281 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -299,10 +299,13 @@ def test_fused_moe_batched_experts( torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + def chunk_by_rank(t, r, w): - num = t.shape[0] - assert num % w == 0, f"{num}, {w}" # for now - chunk = num // w + chunk = rank_chunk(t.shape[0], r, w) #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") return t[(r * chunk):(r + 1)*chunk] @@ -312,12 +315,11 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] - num_local_experts = w1.shape[0] // pgi.world_size block_size = 128 device = pgi.device - rank_num_tokens = num_tokens // pgi.world_size rank = pgi.rank world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") @@ -354,7 +356,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") + #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -372,8 +374,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): #max_num = tokens_per_expert.max() tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) - print(f"tpe {tokens_per_expert}") - print(f"ent {expert_num_tokens}") + #print(f"tpe {tokens_per_expert}") + #print(f"ent {expert_num_tokens}") #torch.set_printoptions(profile="full") #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX) @@ -501,15 +503,12 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] - num_local_experts = num_experts // pgi.world_size block_size = 128 device = pgi.device - rank_num_tokens = num_tokens // pgi.world_size # TODO even divide - - max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) + max_num_tokens = num_tokens ata = AllToAll( max_num_tokens=max_num_tokens, @@ -558,6 +557,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): out = fused_experts( a_chunk, + # Chunking weights like this only works for batched format chunk_by_rank(w1, rank, world_size), chunk_by_rank(w2, rank, world_size), chunk_topk_weight, @@ -571,7 +571,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): #print(f"OUT {rank}: {out.shape} {out}") - return out[:rank_num_tokens] # chunk_by_rank? + return out[:rank_num_tokens] def _pplx_moe( @@ -624,18 +624,13 @@ def _pplx_moe( nvshmem_finalize() -@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) +# TODO: M == 1 doesn't work +@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024])# , 2048]) +@pytest.mark.parametrize("k", [128, 512]) # , 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128]) -# @pytest.mark.parametrize("k", [128]) -# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -# @pytest.mark.parametrize("topk", [2]) #TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 75cc3f29ca6a..ab976ed04e61 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1777,9 +1777,6 @@ def dispatch( num_tokens = a1.shape[0] topk = topk_ids.shape[1] - #assert num_experts % self.world_size == 0 - #num_local_experts = num_experts // self.world_size - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) max_num_tokens = tokens_per_expert.max() expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) @@ -1892,31 +1889,20 @@ def apply( num_tokens, topk = topk_ids.shape _, tmp_max_num_tokens, K = hidden_states.shape max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens - print(f"global_num_experts = {global_num_experts}") + #print(f"global_num_experts = {global_num_experts}") num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() - #assert num_local_experts >= topk_ids.view(-1).max() - #print(f"apply a={hidden_states}") - #print(f"apply topk={topk_ids}") - #print(f"apply num_tokens={expert_num_tokens}") + #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}") for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] - assert num <= max_num_tokens + assert num <= max_num_tokens, f"{num}, {max_num_tokens}" + #print(f"{type(num)}, {num}, {max_num_tokens}") if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) - # fill remainder with 0??? - #out[expert, num:, :].fill_(0) - else: - #out[expert, :, :].fill_(0) # ?? - pass - - #print("END EXPERTS") - - #print(f"apply out={out}") return out diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 10c02fb2ff24..90bfa385dacb 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -84,7 +84,7 @@ def dispatch( dtype=a1q.dtype, device=device, ) - expert_x.fill_(0) #torch.nan # debugging, remove later + #expert_x.fill_(0) #torch.nan # debugging, remove later logger.debug(f"GOT HERE B {self.rank}") From e0560d5c178dda40523e56f6cf47e05ed460c85b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 19 Apr 2025 01:49:57 +0000 Subject: [PATCH 148/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 6 ++++-- .../layers/fused_moe/fused_moe.py | 18 +++++++++++++++--- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index 0e5e0cd77281..a8ce6c6dc2be 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -535,14 +535,14 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, # // world_size? + max_num_tokens, pgi.world_size, dp_size, rank, a.dtype, ) - experts = BatchedExperts(max_num_tokens, rank) + experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -560,6 +560,8 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): # Chunking weights like this only works for batched format chunk_by_rank(w1, rank, world_size), chunk_by_rank(w2, rank, world_size), + #w1, + #w2, chunk_topk_weight, chunk_topk_ids, global_num_experts=num_experts #? num_local_experts? diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ab976ed04e61..4187c122cacc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1827,12 +1827,18 @@ def combine( #print(f"END COMBINE {hex(id(self))}") +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: Optional[int] = None, rank: int = 0, + world_size: int = 1, + max_num_tokens: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1847,6 +1853,7 @@ def __init__( assert block_m is None self.max_num_tokens = max_num_tokens self.rank = rank + self.world_size = world_size def workspace_shapes( self, @@ -1895,14 +1902,19 @@ def apply( num_local_experts = expert_num_tokens.numel() #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}") + # TODO: don't need world_size or rank if expert_base always == 0 + #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" + #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank + expert_base = 0 + for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] assert num <= max_num_tokens, f"{num}, {max_num_tokens}" #print(f"{type(num)}, {num}, {max_num_tokens}") if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) - self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) - out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) return out diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 09b0f5a7e114..330bc6e6a078 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -246,8 +246,8 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine #print(f"block_m = {block_m}") if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info("BatchedExperts") - experts = BatchedExperts() + logger.info(f"BatchedExperts {self.moe}") + experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size) else: experts = TritonExperts( use_fp8_w8a8 = False, From a87645419c786e837f04f8f9137f2f156fb07088 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:40:35 +0000 Subject: [PATCH 149/171] wip Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 3 -- vllm/forward_context.py | 2 +- .../layers/fused_moe/fused_moe.py | 23 +++++++------- vllm/model_executor/layers/fused_moe/layer.py | 31 ++++++++++--------- .../layers/fused_moe/triton_deep_gemm_moe.py | 7 ++--- 5 files changed, 32 insertions(+), 34 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index a8ce6c6dc2be..97fc74e3bd3c 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -23,10 +23,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, torch_moe, torch_moe_single) -#from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -#from vllm.model_executor.layers.fused_moe import fused_moe -#from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) from vllm.platforms import current_platform diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 1afdf88ec2da..99c5bba8bc47 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -93,7 +93,7 @@ def set_forward_context(attn_metadata: Any, from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) #TODO device? - max_tokens_across_dp = torch.max(num_tokens_tensor).to(device="cuda") + max_tokens_across_dp = torch.max(num_tokens_tensor) #.to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_rank_num_tokens = torch.tensor( [num_tokens], diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 4187c122cacc..846e9e2ec477 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1594,8 +1594,9 @@ def workspace_shapes( topk: int, num_experts: int, ) -> Tuple[int, int, torch.dtype]: - workspace1 = M * topk * max(N * 2, K) - workspace2 = M * topk * N + factor = num_experts if a.dim() == 3 else 1 + workspace1 = M * topk * max(N * 2, K) * factor + workspace2 = M * topk * N * factor return (workspace1, workspace2, a.dtype) def apply( @@ -1686,16 +1687,15 @@ def apply( global_num_experts, expert_map )) else: - #stride = hidden_states.shape[1] - sorted_token_ids = torch.arange(0, num_tokens*hidden_states.shape[1], device=hidden_states.device, dtype=torch.int) + max_num_tokens = hidden_states.shape[1] + sorted_token_ids = torch.arange(0, hidden_states.shape[0] * max_num_tokens, device=hidden_states.device, dtype=torch.int) sorted_token_ids = sorted_token_ids.flatten() - nans = torch.isnan(hidden_states).sum(dim=(1,2)) - expert_ids = torch.where((nans > 0).flatten(), -1, torch.arange(0, nans.numel(), device=hidden_states.device, dtype=torch.int32)) - #expert_ids = torch.repeat_interleave(expert_ids, hidden_states.shape[1], dim=0) - #print(f"EXPERT_IDS {nans.shape} {expert_ids}") + expert_ids = torch.arange(0, global_num_experts, device=hidden_states.device, dtype=torch.int) + expert_ids = torch.repeat_interleave(expert_ids, max_num_tokens, dim=0) + print(f"EXPERT_IDS {expert_ids}") #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32) num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32) - num_tokens_post_padded.fill_(num_tokens) + num_tokens_post_padded.fill_(max_num_tokens) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) #print(f"P = {sorted_token_ids}, {hidden_states.shape}") @@ -1857,19 +1857,18 @@ def __init__( def workspace_shapes( self, - a_dtype: torch.dtype, + a: torch.Tensor, M: int, N: int, K: int, topk: int, num_experts: int, - a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: #assert self.max_num_tokens >= a.shape[1] max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack workspace2 = max_num_tokens * N - return (workspace13, workspace2, a_dtype) + return (workspace13, workspace2, a.dtype) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 330bc6e6a078..494161f12929 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -249,6 +249,7 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine logger.info(f"BatchedExperts {self.moe}") experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size) else: + logger.info(f"TritonExperts {self.moe}") experts = TritonExperts( use_fp8_w8a8 = False, use_int8_w8a16 = False, @@ -1011,21 +1012,20 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.use_direct_call: return self.forward_impl(hidden_states, router_logits) - else: + elif True: return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): - max_tokens_across_dp = get_forward_context( - ).dp_metadata.max_tokens_across_dp - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - num_tokens_across_dp = get_forward_context( - ).dp_metadata.num_tokens_across_dp + ctx = get_forward_context() + + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp + #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu + num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp - #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{ctx.dp_metadata.dp_rank_num_tokens}") #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens @@ -1042,17 +1042,19 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") + assert full_hidden_states.shape[0] == full_router_logits.shape[0] + for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - #print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape}") - cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( max=moe_dp_chunk_size_per_rank), dim=0) + print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape} {cu_tokens_across_dp_this_iter}") + hidden_states = self.naive_multicast( hidden_states, cu_tokens_across_dp_this_iter) router_logits = self.naive_multicast( @@ -1087,14 +1089,14 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - #print(f"final2 (AR) = {final_hidden_states.shape}") + print(f"final2 (AR) = {final_hidden_states.shape}") if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - #print(f"final3 (AR) = {final_hidden_states.shape}") + print(f"final3 (AR) = {final_hidden_states.shape}") full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) @@ -1128,8 +1130,9 @@ def forward_impl(self, hidden_states: torch.Tensor, assert self.quant_method is not None if self.dp_size > 1: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu + print("FORWARD_IMPL") + ctx = get_forward_context() + cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu hidden_states = self.naive_multicast(hidden_states, cu_tokens_across_dp_cpu) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index be28d620f47d..e85f35141602 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -37,21 +37,20 @@ def __init__( def workspace_shapes( self, - a_dtype: torch.dtype, + a: torch.Tensor, M: int, N: int, K: int, topk: int, num_experts: int, - a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): - return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) + return self.deep_gemm_expert.workspace_shapes(a, M, N, K, topk, num_experts) else: - return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) + return self.triton_expert.workspace_shapes(a, M, N, K, topk, num_experts) def apply( self, From 9396364aac7b448f8465ab8cb0a8b83689ab9fb2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 16:44:42 +0000 Subject: [PATCH 150/171] undo random changes Signed-off-by: Bill Nell --- csrc/custom_all_reduce.cuh | 2 +- vllm/distributed/parallel_state.py | 8 ++++---- vllm/model_executor/models/mllama.py | 25 ------------------------- 3 files changed, 5 insertions(+), 30 deletions(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 186abf4712fd..44709b459776 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -602,4 +602,4 @@ class CustomAllreduce { * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, half *, int, int, int); */ -} // namespace vllm +} // namespace vllm \ No newline at end of file diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ade7b5183ddf..904f07dac5de 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -34,10 +34,10 @@ import torch import torch.distributed -from torch.distributed import Backend, ProcessGroup from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init, - nvshmem_finalize) + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) +from torch.distributed import Backend, ProcessGroup import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( @@ -917,6 +917,7 @@ def init_distributed_environment( PPLX_DID_INIT: bool = False + @run_once def pplx_init(rank, world_size): if world_size > 1: @@ -1131,7 +1132,6 @@ def destroy_model_parallel(): _DP = None - def destroy_distributed_environment(): global _WORLD if _WORLD: diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 971a4e695dab..0c1d61c01f91 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1245,31 +1245,6 @@ def unpack_data(self, output_tensor[i, :t.size(0)] = t return output_tensor - def unpack_data(self, - image_data: Union[List[torch.Tensor], torch.Tensor], - padding_value=0) -> torch.Tensor: - if isinstance(image_data, torch.Tensor): - # torch.Tensor - return image_data - else: - assert isinstance( - image_data[0], - torch.Tensor), "Image data is not properly batched." - # List[torch.Tensor] - bsz = len(image_data) - max_length = max(t.size(0) for t in image_data) - trailing_dims = image_data[0].shape[1:] - for data in image_data: - cur_trailing_dims = data.shape[1:] - assert cur_trailing_dims == trailing_dims - output_tensor = torch.full((bsz, max_length, *trailing_dims), - padding_value, - dtype=image_data[0].dtype, - device=image_data[0].device) - for i, t in enumerate(image_data): - output_tensor[i, :t.size(0)] = t - return output_tensor - def _parse_and_validate_image_input(self, **kwargs: object): # tensor with the same shape will be batched together by # MultiModalKwargs.batch, so pixel_values here can be: From 47f32c7f36fb7317f76d674fb98151f1d421203e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 20:29:06 +0000 Subject: [PATCH 151/171] merge Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 180 +++--------------- tests/kernels/moe/test_triton_moe_ptpc_fp8.py | 34 ++-- tests/kernels/quantization/test_block_fp8.py | 32 +--- .../layers/fused_moe/fused_batched_moe.py | 17 +- .../layers/fused_moe/fused_moe.py | 115 ++++------- vllm/model_executor/layers/fused_moe/layer.py | 51 +---- .../layers/fused_moe/modular_kernel.py | 14 +- .../layers/fused_moe/pplx_dispatch_combine.py | 29 +-- 8 files changed, 105 insertions(+), 367 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 1807e1b22be7..acf3636e77b9 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -108,143 +108,6 @@ def test_fused_moe( rtol=0) -def torch_dispatch( - a: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int -) -> torch.Tensor: - assert topk_ids.dim() == 2 - assert topk_ids.shape[0] == a.shape[0] - - num_tokens = a.shape[0] - topk = topk_ids.shape[1] - - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - - max_num_tokens = tokens_per_expert.max() - b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), - dtype=a.dtype, device=a.device) - #print(f"b_a shape {b_a.shape}") - - token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) - - for token in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[token, j] - idx = token_counts[expert_id] - b_a[expert_id, idx:idx+1, :] = a[token, :] - token_counts[expert_id] = token_counts[expert_id] + 1 - - return b_a, tokens_per_expert - - -def torch_combine(b_out, topk_weight, topk_ids): - num_tokens, topk = topk_ids.shape - num_experts = b_out.shape[0] - K = b_out.shape[-1] - out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) - for token in range(num_tokens): - expert_ids = topk_ids[token] - for i in range(expert_ids.numel()): - expert_id = expert_ids[i] - idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - return out - - -def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): - num_experts = w1.shape[0] - b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) - assert b_a.dim() == 3 - num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = b_a.shape - assert num_experts == b_a.shape[0] and K == w2.shape[1] - out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) - for expert in range(num_experts): - num = tokens_per_expert[expert] - if num > 0: - torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) - out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) - - return torch_combine(out, topk_weight, topk_ids) - - -# TODO: same as torch_moe but with fused_topk factored out. -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_batched_experts( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, -): - current_platform.seed_everything(7) - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - if True: - triton_output = torch_batched_moe(a, - w1, - w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) - - if False: - torch.set_printoptions(profile="full") - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(triton_output) - - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - - @pytest.mark.parametrize("m", [1, 32, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @@ -570,27 +433,28 @@ def test_fused_marlin_moe( topk_weights, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) - - marlin_output = torch.ops.vllm.fused_marlin_moe( - a, - qweight1, - qweight2, - scales1, - scales2, - score, - topk_weights, - topk_ids, - global_num_experts=e, - expert_map=e_map, - g_idx1=g_idx1, - g_idx2=g_idx2, - sort_indices1=sort_indices1, - sort_indices2=sort_indices2, - w1_zeros=zeros1, - w2_zeros=zeros2, - num_bits=num_bits, - is_k_full=is_k_full) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + + marlin_output = torch.ops.vllm.fused_marlin_moe( + a, + qweight1, + qweight2, + scales1, + scales2, + score, + topk_weights, + topk_ids, + global_num_experts=e, + expert_map=e_map, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=zeros1, + w2_zeros=zeros2, + num_bits=num_bits, + is_k_full=is_k_full) torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 44734e9340aa..3b5838a99fa1 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,6 +7,7 @@ import torch from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.platforms import current_platform @@ -15,6 +16,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): """Matrix multiplication function that supports per-token input @@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization - ) + with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) # Check results rel_diff = (torch.mean( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index d0eca89c04e0..78200666378e 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -30,6 +30,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] @@ -210,10 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() - vllm_config.scheduler_config.max_num_seqs = 128 - vllm_config.scheduler_config.max_model_len = 8192 - with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -261,6 +261,7 @@ def per_block_cast_to_fp8( @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -425,26 +426,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - if True: - dgm = modular_deep_gemm_fused_moe_fp8() - - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids): - return dgm(a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s) - else: - deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() - vllm_config.scheduler_config.max_num_seqs = 128 - vllm_config.scheduler_config.max_model_len = 8192 - with set_current_vllm_config(vllm_config): if M >= 128: ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, @@ -455,8 +437,8 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 56b1b343c86e..e3279cd37f2c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -24,7 +24,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool = False, + apply_router_weight_on_input: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a1.shape[0] @@ -99,8 +99,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - rank: int = 0, - world_size: int = 1, max_num_tokens: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, @@ -116,8 +114,6 @@ def __init__( assert block_shape is None assert block_m is None self.max_num_tokens = max_num_tokens - self.rank = rank - self.world_size = world_size assert not use_fp8_w8a8, "NYI" assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" @@ -171,12 +167,6 @@ def apply( (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() - # TODO: don't need world_size or rank if expert_base always == 0 - #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" - #expert_base = rank_chunk(w1.shape[0], self.rank, - # self.world_size) * self.rank - expert_base = 0 - for expert in range(num_local_experts): num = expert_num_tokens[expert] assert num <= max_num_tokens, f"{num}, {max_num_tokens}" @@ -184,8 +174,7 @@ def apply( tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation( activation, tmp, hidden_states[expert, :num, :] - @ w1[expert_base + expert].transpose(0, 1)) - out[expert, :num, :] = tmp @ w2[expert_base + - expert].transpose(0, 1) + @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 846e9e2ec477..5bf49a8c2c28 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -3,7 +3,6 @@ import functools import json import os -from math import prod from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -23,19 +22,12 @@ from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, round_up +from vllm.utils import direct_register_custom_op from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled logger = init_logger(__name__) -has_deep_gemm = False -try: - import deep_gemm as dg - has_deep_gemm = True -except ImportError: - pass - @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, @@ -494,7 +486,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - if use_fp8_w8a8: + if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) == B_scale.shape[-2]) @@ -511,20 +503,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k - if use_fp8_w8a8: - assert B_scale is not None - assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) - == B_scale.shape[-2]) - assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) - == B_scale.shape[-1]) - - elif use_int8_w8a16 or use_int4_w4a16: - assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. @@ -1063,8 +1041,7 @@ def inplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> None: + block_shape: Optional[List[int]] = None) -> None: pass @@ -1098,8 +1075,7 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: + block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1129,8 +1105,7 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: + block_shape: Optional[List[int]] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1214,7 +1189,6 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, @@ -1302,6 +1276,19 @@ def fused_experts_impl(hidden_states: torch.Tensor, config = get_config_func(M) + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M * top_k_num * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) + + # This needs separate memory since it's used concurrently with cache1 + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1316,50 +1303,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - block_m = config['BLOCK_SIZE_M'] - assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() - - cache1_view: Tuple[int, ...] = () - cache2_view: Tuple[int, ...] = () - cache3_view: Tuple[int, ...] = () - - if use_dg: - assert w1_scale is not None - assert w2_scale is not None - - # We attempt to transpose and align offline in Fp8MoEMethod, in which - # case these calls will be nops. Otherwise, they'll be performed every - # time the layer is executed. - w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() - w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - - M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - - cache1_view = (M_sum, N) - cache3_view = (M_sum, K) - else: - M_sum = M * top_k_num - cache1_view = (M, top_k_num, N) - cache3_view = (M, top_k_num, K) - - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - cache13 = torch.empty(M_sum * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = cache13[:M_sum * N].view(*cache1_view) - intermediate_cache2 = torch.empty((M_sum, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) - - needs_fp8_quantization = use_fp8_w8a8 or use_dg - - for chunk in range(num_chunks): + for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) @@ -1369,6 +1313,17 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.shape[1]] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1380,8 +1335,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, block_m, global_num_experts, - expert_map)) + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, w1, @@ -1667,9 +1622,6 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") - #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") - #print(f"BLOCK_M = {self.block_m}") - # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, @@ -1720,8 +1672,7 @@ def apply( per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) - self.activation(activation, - intermediate_cache2, + self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 494161f12929..a125715486df 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Callable, List, Optional, Tuple -import pplx_kernels as pplx +import pplx_kernels as pplx # TODO: guard this import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter @@ -243,19 +243,20 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine assert self.fused_experts == fused_experts block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - #print(f"block_m = {block_m}") if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): logger.info(f"BatchedExperts {self.moe}") - experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size) + experts = BatchedExperts() else: logger.info(f"TritonExperts {self.moe}") experts = TritonExperts( use_fp8_w8a8 = False, + use_int8_w8a8 = False, use_int8_w8a16 = False, use_int4_w4a16 = False, block_shape = None, block_m = None, #block_m, + per_channel_quant = False, ) self.fused_experts = FusedMoEModularKernel( @@ -526,7 +527,7 @@ def __init__( # Use expert parallelism instead of tensor parallelism? vllm_config = get_current_vllm_config() use_ep = (vllm_config.parallel_config.enable_expert_parallel - and (self.tp_size * self.dp_size) > 1) + and self.tp_size * self.dp_size > 1) # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 @@ -557,7 +558,6 @@ def __init__( self.ep_size = 1 self.local_num_experts = self.global_num_experts self.expert_map = None - #self.global_num_experts = num_experts redundant? self.top_k = top_k assert intermediate_size % self.tp_size == 0 @@ -578,23 +578,20 @@ def __init__( if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - if current_platform.is_hpu(): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - #print(f"params dtype= {params_dtype}") - moe = MoEConfig( num_experts=self.global_num_experts, - experts_per_token=top_k, # ? must be same as topk_ids.shape[1] + experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, dp_size=self.dp_size, dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - in_dtype = params_dtype, # this is probably not right, where to get? + in_dtype = params_dtype, # this is probably not right, where to get? out_dtype = params_dtype, # ditto. ) @@ -619,14 +616,6 @@ def __init__( dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank - if False: - print(f"max num = {max_num_tokens}") - print(f"world size = {world_size}") - print(f"moe ep size = {moe.ep_size}") - print(f"moe dp size = {moe.dp_size}") - print(f"dp size = {dp_size}") - print(f"rank= {rank}") - all_to_all = get_all_to_all( max_num_tokens=max_num_tokens, num_experts=moe.num_experts, @@ -657,7 +646,7 @@ def __init__( rank, # just for debugging moe.in_dtype, ) - elif False: + elif True: logger.info("using standard dispatch") dispatch_combine = StandardDispatchCombine( moe.in_dtype, @@ -1012,7 +1001,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.use_direct_call: return self.forward_impl(hidden_states, router_logits) - elif True: + else: return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) @@ -1022,11 +1011,9 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, ctx = get_forward_context() max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp - #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu + cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp - #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{ctx.dp_metadata.dp_rank_num_tokens}") - #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. @@ -1039,9 +1026,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") - #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") - assert full_hidden_states.shape[0] == full_router_logits.shape[0] for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1053,8 +1037,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, max=moe_dp_chunk_size_per_rank), dim=0) - print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape} {cu_tokens_across_dp_this_iter}") - hidden_states = self.naive_multicast( hidden_states, cu_tokens_across_dp_this_iter) router_logits = self.naive_multicast( @@ -1078,8 +1060,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, activation=self.activation, ) - #print(f"final1 = {final_hidden_states.shape}") - if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ self.dp_rank - 1] @@ -1089,27 +1069,19 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - print(f"final2 (AR) = {final_hidden_states.shape}") - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - print(f"final3 (AR) = {final_hidden_states.shape}") - full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - #print(f"partial final = {full_final_hidden_states.shape}") - # Update bounds num_tokens_remaining_across_dp = torch.clamp( num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) - #print(f"num remaining = {num_tokens_remaining_across_dp}") - # HACK FIX if num_tokens_remaining_across_dp.sum() == 0: break @@ -1121,8 +1093,6 @@ def update_chunk_bound(x: int): chunk_start = update_chunk_bound(chunk_start) chunk_end = update_chunk_bound(chunk_end) - #print(f"full final shape {full_final_hidden_states.shape}") - return full_final_hidden_states def forward_impl(self, hidden_states: torch.Tensor, @@ -1130,7 +1100,6 @@ def forward_impl(self, hidden_states: torch.Tensor, assert self.quant_method is not None if self.dp_size > 1: - print("FORWARD_IMPL") ctx = get_forward_context() cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 35f8b8292771..d550c8b040c9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -76,7 +76,6 @@ def _moe_problem_size( return E, M, N, K, topk - class FusedMoEQuantizeDispatchCombine(ABC): """ An abstract base class for the [Quantize-Dispatch] and [Combine] steps @@ -107,7 +106,8 @@ def dispatch( - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. Returns a tuple of: - quantized + dispatched a. @@ -132,7 +132,8 @@ def combine( experts, it will have (M, topk, K) shape. - topk_weights: The weights to be applied to the fused_experts_output. - topk_ids: The topk_ids. - - apply_router_weight_on_input: When False, apply the weights to fused_expert_output. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. """ raise NotImplementedError @@ -312,14 +313,9 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - #from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) - #print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") - a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) - #print(f"INIT shape: E={E}, M={M}, N={N}, K={K}, top_k={top_k}") - if global_num_experts == -1: global_num_experts = E @@ -364,6 +360,4 @@ def forward( self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) - #print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") - return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 90bfa385dacb..ef5da7a5d9e3 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -9,8 +9,6 @@ moe_kernel_quantize_input) -logger = init_logger(__name__) - # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. @@ -46,7 +44,6 @@ def dispatch( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device - num_tokens = a1.shape[0] # M hidden_dim = a1.shape[-1] # K assert expert_map is None, "NYI" @@ -75,18 +72,13 @@ def dispatch( dtype=torch.int32, device=device, ) - #expert_num_tokens.fill_(-1) # debugging, remove later num_dp = self.world_size // self.dp_size - logger.debug(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") expert_x = torch.empty( - (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), + (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), dtype=a1q.dtype, device=device, ) - #expert_x.fill_(0) #torch.nan # debugging, remove later - - logger.debug(f"GOT HERE B {self.rank}") expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -103,11 +95,10 @@ def dispatch( device=device, ) - logger.debug(f"GOT HERE C {self.rank}") - # This argument is optional, defaults to indices.shape[0] - # This causes a deadlock???? + # This causes a deadlock? #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #num_tokens = a1.shape[0] # M #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) bound_m = None @@ -133,23 +124,17 @@ def combine( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: - device = fused_expert_output.device - #device = torch.device("cuda", self.rank) - #device = get_dp_group().device - #assert fused_expert_output.device == device - - logger.debug(f"COMBINE START {self.rank}") - # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens #num_tokens = fused_expert_output.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + # device=fused_expert_output.device) bound_m = None assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] - # Set weights to 1 if we did them in dispatch. This is hacky. + # Set weights to 1 if we did them in dispatch. This is hacky. if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) @@ -158,5 +143,3 @@ def combine( weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) - - logger.debug(f"COMBINE END {self.rank}") From 00f8fb2c760787c076c9df47cd2539585bf8b5f2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 20:37:43 +0000 Subject: [PATCH 152/171] tweak Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 405ced54d2ee..696a1cb4d60b 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -523,13 +523,9 @@ def _pplx_moe( m, k = a.shape e, _, n = w2.shape - torch.set_printoptions(profile="full") - with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk) torch_output = chunk_by_rank(torch_output, pgi.rank, From fd4805fc7f658d7ebe01adc9a5b915eca022fe02 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 21:22:58 +0000 Subject: [PATCH 153/171] revert hack Signed-off-by: Bill Nell --- examples/offline_inference/data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 1c0701051890..965915beaf58 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -160,7 +160,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=3000) + proc.join(timeout=300) if proc.exitcode is None: print(f"Killing process {proc.pid} that " f"didn't stop within 5 minutes.") From be22c575057744ec07de1f3530ddb2478ac9f015 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 21:39:44 +0000 Subject: [PATCH 154/171] fixes Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 696a1cb4d60b..b58c2d2c6d3f 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -471,10 +471,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): pgi.world_size, dp_size, rank, - a.dtype, ) - experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) + experts = BatchedExperts(max_num_tokens) fused_experts = FusedMoEModularKernel( dispatch_combine, From 9018df8198523013d79b0f37cc778de277f91c63 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 21:45:57 +0000 Subject: [PATCH 155/171] pplx update Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index b58c2d2c6d3f..aeedadea3852 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -300,7 +300,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens - ata = AllToAll( + ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -448,7 +448,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens - ata = AllToAll( + ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, From 3433b7394e3ef12d8dc2255ecd7dd2f5e685e653 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:17:50 +0000 Subject: [PATCH 156/171] varun's fixes Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 158 +++++ vllm/distributed/parallel_state.py | 4 +- .../layers/fused_moe/fused_batched_moe.py | 627 +++++++++++++++++- vllm/model_executor/layers/fused_moe/layer.py | 52 +- .../layers/fused_moe/pplx_dispatch_combine.py | 18 +- vllm/model_executor/models/deepseek_v2.py | 4 +- 6 files changed, 824 insertions(+), 39 deletions(-) create mode 100644 tests/kernels/moe/test_batched_moe.py diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py new file mode 100644 index 000000000000..ffd69935b461 --- /dev/null +++ b/tests/kernels/moe/test_batched_moe.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + +import pytest +from dataclasses import dataclass + +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + invoke_moe_batched_triton_kernel, + invoke_batched_silu_and_mul) + + +@dataclass +class BatchedMMConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + K: int + N: int + +@dataclass +class BatchedMMTensors: + A: torch.Tensor # [E, max_tokens, K] + B: torch.Tensor # [E, K, N] - column major + C: torch.Tensor # [E, max_tokens, N] + num_expert_tokens: torch.Tensor # [E] + + @staticmethod + def make_tensors(config: BatchedMMConfig): + A = torch.randn((config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", dtype=config.dtype) / 50.0 + B = torch.randn((config.num_experts, config.N, config.K), device="cuda", dtype=config.dtype) / 50.0 + C = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", dtype=config.dtype) + num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) + return BatchedMMTensors(A,B,C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + + return C + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [512]) +@pytest.mark.parametrize("K", [256]) +@pytest.mark.parametrize("N", [512]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batched_mm(num_experts: int, + max_tokens_per_expert: int, + K: int, + N: int, + dtype: torch.dtype): + + config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) + tensors = BatchedMMTensors.make_tensors(config) + + test_output = tensors.C + ref_output = test_output.clone() + + + compute_tl_dtype = {torch.float16 : tl.float16, + torch.bfloat16 : tl.bfloat16, + torch.float32 : tl.float32}[test_output.dtype] + invoke_moe_batched_triton_kernel(tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config = {"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16}) + + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, tensors.num_expert_tokens) + #torch.cuda.synchronize() + #print (f"ref output {ref_output}") + #print (f"test output {test_output}") + + torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3) + + +@dataclass +class BatchedSiluMulConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + D: int + +@dataclass +class BatchedSiluMulTensors: + input: torch.Tensor + output: torch.Tensor + expert_num_tokens: torch.Tensor + + @staticmethod + def make_tensors(config: BatchedSiluMulConfig): + input = torch.randn((config.num_experts, config.max_tokens_per_expert, config.D * 2), device="cuda", dtype=config.dtype) / 50.0 + output = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.D), device="cuda", dtype=config.dtype) + num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) + return BatchedSiluMulTensors(input, output, num_expert_tokens) + + +def ref_batched_silu_mul( + output: torch.Tensor, + input: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e].item() + out_part = output[e, :num_tokens, :] + in_part = input[e, :num_tokens, :] + torch.ops._C.silu_and_mul(out_part, in_part) + + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [128]) +@pytest.mark.parametrize("D", [128, 256]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batched_silu_mul(num_experts: int, + max_tokens_per_expert: int, + D: int, + dtype: torch.dtype): + + config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D) + tensors = BatchedSiluMulTensors.make_tensors(config) + + test_out = tensors.output + ref_out = torch.zeros_like(test_out) + + ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens) + + invoke_batched_silu_and_mul(test_out, tensors.input, tensors.expert_num_tokens) + + torch.testing.assert_close(test_out, ref_out) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 904f07dac5de..cf715681c878 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -923,12 +923,12 @@ def pplx_init(rank, world_size): if world_size > 1: try: global PPLX_DID_INIT - print(f"PPLX_INIT {rank} {world_size}") + logger.debug(f"PPLX_INIT {rank} {world_size}") uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() uid_gpu = uid.cuda() get_world_group().broadcast(uid_gpu, src=0) - print(f"PPLX_INIT UID={uid_gpu}") + logger.debug(f"PPLX_INIT UID={uid_gpu}") uid = uid_gpu.to(device='cpu') nvshmem_init(uid, rank, world_size) PPLX_DID_INIT = True diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index e3279cd37f2c..907670cbb7b8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -3,9 +3,465 @@ from typing import List, Optional, Tuple import torch +import triton +import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_config_dtype_str, + try_get_optimal_moe_config, +) + +@triton.jit +def batched_silu_and_mul_kernel(output, # [E, MAX_NUM_TOKENS, D] + input, # [E, MAX_NUM_TOKENS, D * 2] + expert_num_tokens, # [E] + stride_oe, + stride_om, + stride_ie, + stride_im, + compute_type: tl.constexpr, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr): + + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # early exit + return + + pid_m = tl.program_id(axis=1) + cta_m_start = pid_m * BLOCK_M + if cta_m_start >= e_num_tokens: + # early exit + return + + cta_input_ptr = input + expert_id * stride_ie + cta_m_start * stride_im + cta_output_ptr = output + expert_id * stride_oe + cta_m_start * stride_om + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + offs_m = tl.arange(0, BLOCK_M)[:, None] + mask_m = offs_m < cta_m_size + + cta_input_ptrs = cta_input_ptr + offs_m * stride_im + cta_output_ptrs = cta_output_ptr + offs_m * stride_om + + # offset by D + offs_D = tl.arange(0, BLOCK_D) + cta_input_ptrs = cta_input_ptrs + offs_D + cta_output_ptrs = cta_output_ptrs + offs_D + + for d in range(0, tl.cdiv(D, BLOCK_D)): + mask_D = offs_D < (D - (d * BLOCK_D)) + mask_tile = mask_m & mask_D + + x_tile = tl.load(cta_input_ptrs, mask=mask_tile, other=0.0).to(dtype=tl.float32) + y_tile = tl.load(cta_input_ptrs + D, mask=mask_tile, other=0.0) + + # silu and mul + out_tile = (x_tile * (1.0 / (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) + out_tile = out_tile * y_tile + tl.store(cta_output_ptrs, out_tile, mask=mask_tile) + + cta_input_ptrs = cta_input_ptrs + BLOCK_D + cta_output_ptrs = cta_output_ptrs + BLOCK_D + +@triton.jit +def moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): + + offs_k = tl.arange(0, BLOCK_K) + + if use_w8a16: + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + offs_bsn = offs_n // group_n + b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + + offs_bsn * stride_bsn) + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + expert_id) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=mask_m[:, None] & + (offs_k[None, :] < K - k * BLOCK_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_K, + other=0.0) + # We accumulate along the K dimension. + if use_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=mask_m, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + if use_w8a8: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + if use_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + + return accumulator + + +@triton.jit +def expert_triton_kernel(a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) % N + offs_k = tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn + + + accumulator = moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n, + group_k, + # Meta-parameters + BLOCK_M, + BLOCK_N, + BLOCK_K, + compute_type, + use_fp8_w8a8, + use_int8_w8a16) + + # store in C + offs_cn = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = mask_m[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +@triton.jit +def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ae, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_ce, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # Early exit + return + + pid_mn = tl.program_id(axis=1) + num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid_mn // num_pid_n + pid_n = pid_mn % num_pid_n + + cta_m_start = pid_m * BLOCK_M + cta_n_start = pid_n * BLOCK_N + if cta_m_start >= e_num_tokens: + # Early exit + return + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + cta_n_size = min(BLOCK_N, N - cta_n_start) + + a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am + b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn + c_ptr = c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + cta_n_start * stride_cn + + expert_triton_kernel(a_ptr, + b_ptr, + c_ptr, + expert_id, + compute_type, + cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # Strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M, + BLOCK_N, + BLOCK_K) + + +def invoke_moe_batched_triton_kernel(A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: torch.Tensor, + B_scale: torch.Tensor, + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): + + assert not use_int4_w4a16 + max_num_tokens = A.size(1) + K = A.size(2) + N = C.size(2) + + BLOCK_M = config['BLOCK_SIZE_M'] + BLOCK_N = config['BLOCK_SIZE_N'] + BLOCK_K = config['BLOCK_SIZE_K'] + assert max_num_tokens % BLOCK_M == 0 + + grid = (expert_num_tokens.size(0), + triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.shape[1], BLOCK_N)) + + batched_triton_kernel[grid](A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M = BLOCK_M, + BLOCK_N = BLOCK_N, + BLOCK_K = BLOCK_K) + + +def invoke_batched_silu_and_mul(output : torch.Tensor, #[E, MAX_TOKENS, D] + input: torch.Tensor, #[E, MAX_TOKENS, D * 2] + expert_num_tokens: torch.Tensor): + + + num_experts = output.size(0) + max_num_tokens = output.size(1) + D = output.size(2) + + BLOCK_D = 1024 + BLOCK_M = 1 + + compute_tl_dtype = {torch.float16 : tl.float16, + torch.float32 : tl.float32, + torch.bfloat16 : tl.bfloat16}[output.dtype] + + #print(f"compute type {compute_tl_dtype}") + + grid = (num_experts, triton.cdiv(max_num_tokens, BLOCK_M)) + batched_silu_and_mul_kernel[grid](output, + input, + expert_num_tokens, + output.stride(0), + output.stride(1), + input.stride(0), + input.stride(1), + compute_tl_dtype, + D, + BLOCK_M, + BLOCK_D) class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): @@ -90,11 +546,6 @@ def combine( expert_counts[expert_id] = expert_counts[expert_id] + 1 -def rank_chunk(num, r, w): - rem = num % w - return (num // w) + (1 if r < rem else 0) - - class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -108,16 +559,13 @@ def __init__( block_m: Optional[int] = None, ): super().__init__() - assert not use_fp8_w8a8 - assert not use_int4_w4a16 - assert not use_int8_w8a16 assert block_shape is None assert block_m is None - self.max_num_tokens = max_num_tokens assert not use_fp8_w8a8, "NYI" assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + self.max_num_tokens = max_num_tokens def workspace_shapes( self, @@ -178,3 +626,164 @@ def apply( out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out + + +class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + self.max_num_tokens = max_num_tokens + assert not use_int8_w8a8, "NYI" + assert not use_int4_w4a16, "NYI" + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[ + 1] if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * max(K, N) + workspace2 = num_experts * max_num_tokens * (N // 2) + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + + num_tokens = topk_ids.size(0) + #print_debug = expert_map[0] != -1 and num_tokens < 50 and num_tokens != 1 and False + + # Check constraints. + if self.use_int4_w4a16: + assert hidden_states.shape[-1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[-1] == w1.shape[ + 2], f"Hidden size mismatch {hidden_states.shape[-1]} != {w1.shape[2]}" + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + assert w1.shape[0] == E + assert w2.shape[0] == E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.shape, + w2.shape, + top_k_num, + config_dtype, + num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, + (E, num_tokens, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) + + # MM1 + invoke_moe_batched_triton_kernel(A=hidden_states, + B=w1, + C=intermediate_cache1, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a1q_scale, + B_scale=w1_scale, + B_zp=w1_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + # Fix activations + assert activation == "silu" + invoke_batched_silu_and_mul(output=intermediate_cache2, + input=intermediate_cache1, + expert_num_tokens=expert_num_tokens) + + qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + # TODO (varun) : support w8a8 + assert not self.use_fp8_w8a8 + #if self.use_fp8_w8a8: + # qintermediate_cache2, a2q_scale = _fp8_quantize( + # intermediate_cache2, a2_scale, self.block_shape) + + invoke_moe_batched_triton_kernel(A=intermediate_cache2, + B=w2, + C=intermediate_cache3, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a2q_scale, + B_scale=w2_scale, + B_zp=w2_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a125715486df..d5f2b165e8b4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,8 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts + from .fused_moe import TritonExperts, fused_experts + from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -117,7 +118,8 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) if instance is None: - instance = pplx.AllToAll(**kwargs) + # TODO: should be intranode + instance = pplx.AllToAll.internode(**kwargs) self._cache[key] = instance return instance @@ -245,8 +247,14 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info(f"BatchedExperts {self.moe}") - experts = BatchedExperts() + logger.info(f"BatchedTritonExperts {self.moe}") + experts = BatchedTritonExperts( + use_fp8_w8a8 = False, + use_int8_w8a8 = False, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = None, + ) else: logger.info(f"TritonExperts {self.moe}") experts = TritonExperts( @@ -255,7 +263,6 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine use_int8_w8a16 = False, use_int4_w4a16 = False, block_shape = None, - block_m = None, #block_m, per_channel_quant = False, ) @@ -1037,10 +1044,12 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, max=moe_dp_chunk_size_per_rank), dim=0) - hidden_states = self.naive_multicast( - hidden_states, cu_tokens_across_dp_this_iter) - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_dp_this_iter) + # TODO: still may be needed for non-pplx, put into dispatcher class. + if False: + hidden_states = self.naive_multicast( + hidden_states, cu_tokens_across_dp_this_iter) + router_logits = self.naive_multicast( + router_logits, cu_tokens_across_dp_this_iter) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -1060,7 +1069,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, activation=self.activation, ) - if self.dp_size > 1: + # TODO: needed for non-pplx? + if False and self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ self.dp_rank - 1] end = cu_tokens_across_dp_this_iter[self.dp_rank] @@ -1069,7 +1079,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # TODO: needed for non-pplx? + if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1090,8 +1101,14 @@ def update_chunk_bound(x: int): return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) - chunk_start = update_chunk_bound(chunk_start) - chunk_end = update_chunk_bound(chunk_end) + #chunk_start = update_chunk_bound(chunk_start) + #chunk_end = update_chunk_bound(chunk_end) + if chunk_end == full_hidden_states.shape[0]: + # simply redo computation + pass + else: + chunk_start = update_chunk_bound(chunk_start) + chunk_end = update_chunk_bound(chunk_end) return full_final_hidden_states @@ -1099,7 +1116,8 @@ def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None - if self.dp_size > 1: + # TODO: still may be needed for non-pplx + if False and self.dp_size > 1: ctx = get_forward_context() cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu @@ -1127,7 +1145,8 @@ def forward_impl(self, hidden_states: torch.Tensor, apply_router_weight_on_input=self.apply_router_weight_on_input, ) - if self.dp_size > 1: + # TODO: needed for non-pplx? + if False and self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ self.dp_rank - 1] end = cu_tokens_across_dp_cpu[self.dp_rank] @@ -1135,7 +1154,8 @@ def forward_impl(self, hidden_states: torch.Tensor, all_hidden_states = get_dp_group().all_reduce(final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # TODO: needed for non-pplx? + if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index ef5da7a5d9e3..576c454ec31d 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -46,7 +46,8 @@ def dispatch( device = a1.device hidden_dim = a1.shape[-1] # K - assert expert_map is None, "NYI" + # ?? + # assert expert_map is None, "NYI" if apply_router_weight_on_input: topk = rank_topk_ids.shape[1] @@ -96,11 +97,8 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - # This causes a deadlock? - #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens - #num_tokens = a1.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) - bound_m = None + num_tokens = a1.shape[0] # M + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) @@ -125,11 +123,9 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens - #num_tokens = fused_expert_output.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, - # device=fused_expert_output.device) - bound_m = None + num_tokens = output.shape[0] # M + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + device=fused_expert_output.device) assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ffa5840b4604..350e3a592178 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -171,7 +171,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # See DeepseekV2DecoderLayer for more details. final_hidden_states = final_hidden_states + shared_output \ * (1. / self.routed_scaling_factor) - if self.tp_size > 1: + + # TODO: check if needed for non-pplx? + if False and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) From 800dde167b95f7ee114f3b8547f8d87559559b43 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:23:35 +0000 Subject: [PATCH 157/171] varun's fixes Signed-off-by: Bill Nell --- .../layers/fused_moe/pplx_dispatch_combine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 576c454ec31d..f88044da0201 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -123,9 +123,10 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - num_tokens = output.shape[0] # M - bound_m = torch.tensor([num_tokens], dtype=torch.uint32, - device=fused_expert_output.device) + #num_tokens = output.shape[0] # M + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + # device=fused_expert_output.device) + bound_m = None assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] From 918e62b4d007fcb90c0dee1af362edb380bfb6f9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:25:51 +0000 Subject: [PATCH 158/171] tweak bound_m Signed-off-by: Bill Nell --- .../layers/fused_moe/pplx_dispatch_combine.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index f88044da0201..576c454ec31d 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -123,10 +123,9 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - #num_tokens = output.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, - # device=fused_expert_output.device) - bound_m = None + num_tokens = output.shape[0] # M + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + device=fused_expert_output.device) assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] From b6ae86151537e9c528c706f751f155dd73b647e6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:59:42 +0000 Subject: [PATCH 159/171] run linter Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 134 ++-- tests/kernels/moe/test_moe.py | 4 +- tests/kernels/quantization/test_block_fp8.py | 5 +- tests/kernels/test_block_fp8.py | 499 ------------- tests/kernels/test_pplx_moe.py | 654 ------------------ vllm/forward_context.py | 9 +- .../layers/fused_moe/deep_gemm_moe.py | 4 +- .../layers/fused_moe/fused_batched_moe.py | 500 ++++++------- .../layers/fused_moe/fused_moe.py | 241 ++----- vllm/model_executor/layers/fused_moe/layer.py | 97 +-- .../layers/fused_moe/modular_kernel.py | 4 +- .../layers/fused_moe/pplx_dispatch_combine.py | 12 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 40 +- .../model_executor/layers/quantization/fp8.py | 24 +- 14 files changed, 471 insertions(+), 1756 deletions(-) delete mode 100644 tests/kernels/test_block_fp8.py delete mode 100644 tests/kernels/test_pplx_moe.py diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index ffd69935b461..1bb8f4e09ddf 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 -import torch -import triton -import triton.language as tl +from dataclasses import dataclass import pytest -from dataclasses import dataclass +import torch +import triton.language as tl from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - invoke_moe_batched_triton_kernel, - invoke_batched_silu_and_mul) + invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel) @dataclass @@ -20,25 +18,36 @@ class BatchedMMConfig: K: int N: int + @dataclass class BatchedMMTensors: A: torch.Tensor # [E, max_tokens, K] B: torch.Tensor # [E, K, N] - column major C: torch.Tensor # [E, max_tokens, N] - num_expert_tokens: torch.Tensor # [E] + num_expert_tokens: torch.Tensor # [E] @staticmethod def make_tensors(config: BatchedMMConfig): - A = torch.randn((config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", dtype=config.dtype) / 50.0 - B = torch.randn((config.num_experts, config.N, config.K), device="cuda", dtype=config.dtype) / 50.0 - C = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", dtype=config.dtype) - num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) - return BatchedMMTensors(A,B,C, num_expert_tokens) - - -def ref_impl(A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, + A = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.dtype) / 50.0 + B = torch.randn((config.num_experts, config.N, config.K), + device="cuda", + dtype=config.dtype) / 50.0 + C = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.N), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) + return BatchedMMTensors(A, B, C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, num_expert_tokens: torch.Tensor) -> torch.Tensor: num_expert_tokens_cpu = num_expert_tokens.clone() @@ -49,19 +58,16 @@ def ref_impl(A: torch.Tensor, num_tokens = num_expert_tokens_cpu[e] C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) - return C + @pytest.mark.parametrize("num_experts", [16, 32]) @pytest.mark.parametrize("max_tokens_per_expert", [512]) @pytest.mark.parametrize("K", [256]) @pytest.mark.parametrize("N", [512]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_batched_mm(num_experts: int, - max_tokens_per_expert: int, - K: int, - N: int, - dtype: torch.dtype): +def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, + N: int, dtype: torch.dtype): config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) tensors = BatchedMMTensors.make_tensors(config) @@ -69,29 +75,33 @@ def test_batched_mm(num_experts: int, test_output = tensors.C ref_output = test_output.clone() - - compute_tl_dtype = {torch.float16 : tl.float16, - torch.bfloat16 : tl.bfloat16, - torch.float32 : tl.float32}[test_output.dtype] - invoke_moe_batched_triton_kernel(tensors.A, - tensors.B, - test_output, - tensors.num_expert_tokens, - compute_tl_dtype, - # Quantization data - None, - None, - None, - # Quantization schemes - False, - False, - False, - config = {"BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16}) - - - ref_output = ref_impl(tensors.A, tensors.B, ref_output, tensors.num_expert_tokens) + compute_tl_dtype = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32 + }[test_output.dtype] + invoke_moe_batched_triton_kernel( + tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config={ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16 + }) + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, + tensors.num_expert_tokens) #torch.cuda.synchronize() #print (f"ref output {ref_output}") #print (f"test output {test_output}") @@ -106,6 +116,7 @@ class BatchedSiluMulConfig: max_tokens_per_expert: int D: int + @dataclass class BatchedSiluMulTensors: input: torch.Tensor @@ -114,16 +125,24 @@ class BatchedSiluMulTensors: @staticmethod def make_tensors(config: BatchedSiluMulConfig): - input = torch.randn((config.num_experts, config.max_tokens_per_expert, config.D * 2), device="cuda", dtype=config.dtype) / 50.0 - output = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.D), device="cuda", dtype=config.dtype) - num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) + input = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.D * 2), + device="cuda", + dtype=config.dtype) / 50.0 + output = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.D), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) return BatchedSiluMulTensors(input, output, num_expert_tokens) -def ref_batched_silu_mul( - output: torch.Tensor, - input: torch.Tensor, - num_expert_tokens: torch.Tensor) -> torch.Tensor: +def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: num_expert_tokens_cpu = num_expert_tokens.clone() num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") @@ -140,10 +159,8 @@ def ref_batched_silu_mul( @pytest.mark.parametrize("max_tokens_per_expert", [128]) @pytest.mark.parametrize("D", [128, 256]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_batched_silu_mul(num_experts: int, - max_tokens_per_expert: int, - D: int, - dtype: torch.dtype): +def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int, + dtype: torch.dtype): config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D) tensors = BatchedSiluMulTensors.make_tensors(config) @@ -153,6 +170,7 @@ def test_batched_silu_mul(num_experts: int, ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens) - invoke_batched_silu_and_mul(test_out, tensors.input, tensors.expert_num_tokens) + invoke_batched_silu_and_mul(test_out, tensors.input, + tensors.expert_num_tokens) torch.testing.assert_close(test_out, ref_out) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index acf3636e77b9..6822b88300e1 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -15,8 +15,7 @@ torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -26,7 +25,6 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.model_executor.layers.activation import SiluAndMul NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 78200666378e..28e73fdd7c0e 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -437,8 +437,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py deleted file mode 100644 index 762d02394086..000000000000 --- a/tests/kernels/test_block_fp8.py +++ /dev/null @@ -1,499 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Adapted from https://github.com/sgl-project/sglang/pull/2575 -import itertools - -import pytest -import torch - -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) -from vllm.platforms import current_platform - -dg_available = False -try: - import deep_gemm - dg_available = True -except ImportError: - pass - -if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) - -# Test configurations -DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] -NUM_TOKENS = [7, 83, 2048] -D = [512, 4096, 5120, 13824] -GROUP_SIZE = [64, 128, 256, 512] -M = [1, 7, 8, 83, 84, 512, 2048, 4096] -N = [128, 512, 1024, 4096, 7168, 7748, 13824] -K = [256, 4096, 5120, 3884, 13824, 16384] -# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 -# and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [1, 128, 192, 512, 1335, 2048] -N_moe = [128, 256, 4608] # [13824] -K_moe = [256, 512, 7168] # [13824] -BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16, 24] # [128, 256] -TOP_KS = [1, 2, 6] -OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] -SEEDS = [0] - - -def native_per_token_group_quant_fp8(x, - group_size, - eps=1e-10, - dtype=torch.float8_e4m3fn): - """Function to perform per-token-group quantization on an input tensor - `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " - "be divisible by `group_size`") - assert x.is_contiguous(), "`x` is not contiguous" - - finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / fp8_max - x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) - - return x_q, x_s - - -def native_w8a8_block_fp8_matmul(A, - B, - As, - Bs, - block_size, - output_dtype=torch.float16): - """Matrix multiplication with block-wise quantization using native torch.""" - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C - - -def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(torch.float32) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = native_w8a8_block_fp8_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_fp8_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - -# Skip all tests if CUDA is not available -pytest.importorskip("torch.cuda") - - -@pytest.fixture(autouse=True) -def setup_cuda(): - torch.set_default_device("cuda") - - -@pytest.mark.parametrize( - "num_tokens,d,dtype,group_size,seed", - itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) -@torch.inference_mode() -def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): - torch.manual_seed(seed) - x = torch.rand(num_tokens, d, dtype=dtype) - - ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) - out, scale = per_token_group_quant_fp8(x, group_size) - - assert torch.allclose(out.to(torch.float32), - ref_out.to(torch.float32), - rtol=0.15) - assert torch.allclose(scale, ref_scale) - - -@pytest.mark.parametrize( - "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - block_n, block_k = block_size[0], block_size[1] - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - - As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale - Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) - out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.001 - - -@pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - if topk > E: - pytest.skip(f"Skipping test; topk={topk} > E={E}") - - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - vllm_config = VllmConfig() - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w2_bf16 - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale - w2_s = torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale - - score = torch.randn((M, E), dtype=dtype) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 - - -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - -@pytest.mark.parametrize( - "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): - # only aligned sizes - if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") - - torch.manual_seed(seed) - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max = fp8_info.max - - A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - - _, block_k = block_size[0], block_size[1] - - A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) - B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) - - As = As_fp8.to(torch.float32) - Bs = Bs_fp8.to(torch.float32) - - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) - - # Transpose earlier so that the testing will not trigger transposing kernels - As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) - - out = torch.zeros((M, N), device='cuda', dtype=out_dtype) - - assert As_fp8.shape == (M, (K + 127) // - 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - - deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.001 - - -def fp8_perm(m, idx): - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) - else: - return m[idx, ...] - - -def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): - M, K = a.shape - - sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, block_m, num_groups, None, pad_sorted_ids=True) - - num_tokens = topk * M - - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:M * topk] - - a = fp8_perm(a, sorted_token_ids // topk) - if a_s is not None: - a_s = a_s[sorted_token_ids // topk] - - return a, a_s, m_indices, inv_perm - - -def _moe_unpermute(out, inv_perm, topk, K, topk_weight): - M = topk_weight.shape[0] - out = out[inv_perm, ...] - tmp_out = out.view(-1, topk, K) - return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" - num_groups = w1.shape[0] - M, K = a.shape - N = w2.shape[-1] - - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - - _, block_k = block_shape[0], block_shape[1] - - a_q, a_s = per_token_group_quant_fp8(a, block_m) - - a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, - num_groups, topk, block_m) - - inter_out = torch.zeros((a_q.shape[0], N * 2), - dtype=torch.bfloat16, - device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), - inter_out, m_indices) - - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - - out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - - final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) - - return final_out - - -@pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, - dtype, seed): - - if topk > E: - pytest.skip(f"Skipping test: topk={topk} > E={E}") - - if not _valid_deep_gemm_shape(M, N, K): - pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") - - vllm_config = VllmConfig() - - torch.manual_seed(seed) - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - score = torch.randn((M, E), dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * N) + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w2 = (N + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - - w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(E): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - if M >= 128: - ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - else: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - - topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py deleted file mode 100644 index 97fc74e3bd3c..000000000000 --- a/tests/kernels/test_pplx_moe.py +++ /dev/null @@ -1,654 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for the MOE layers. - -Run `pytest tests/kernels/test_pplx_moe.py`. -""" -import dataclasses -import os -import pytest -import torch -import traceback - -from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, Optional, ParamSpec, Tuple - -from pplx_kernels import AllToAll -from pplx_kernels.nvshmem import ( - nvshmem_alloc_empty_unique_id, - nvshmem_finalize, - nvshmem_get_unique_id, - nvshmem_init, -) - -import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) -from vllm.platforms import current_platform - -from vllm.model_executor.layers.activation import SiluAndMul - -from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel -from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine - -NUM_EXPERTS = [8, 64] -EP_SIZE = [1, 4] -TOP_KS = [2, 6] - -P = ParamSpec("P") - -require_multi_node = pytest.mark.skipif( - "MASTER_ADDR" not in os.environ, - reason="Requires multi-node environment", -) - - -@dataclasses.dataclass -class ProcessGroupInfo: - world_size: int - world_local_size: int - rank: int - node_rank: int - local_rank: int - device: torch.device - - -def _worker_parallel_launch( - local_rank: int, - world_size: int, - world_local_size: int, - node_rank: int, - init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - rank = node_rank * world_local_size + local_rank - torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - torch.distributed.init_process_group( - backend="cpu:gloo,cuda:nccl", - init_method=init_method, - rank=rank, - world_size=world_size, - device_id=device, - ) - barrier = torch.tensor([rank], device=device) - torch.distributed.all_reduce(barrier) - - try: - worker( - ProcessGroupInfo( - world_size=world_size, - world_local_size=world_local_size, - rank=rank, - node_rank=node_rank, - local_rank=local_rank, - device=device, - ), - *args, - **kwargs, - ) - except Exception as ex: - print(ex) - traceback.print_exception(ex) - raise - finally: - torch.distributed.destroy_process_group() - - -def parallel_launch( - world_size: int, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - assert not kwargs - spawn( - _worker_parallel_launch, - args=( - world_size, - world_size, - 0, - "tcp://localhost:29500", - worker, - ) - + args, - nprocs=world_size, - join=True, - ) - - -def parallel_launch_from_env( - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - """ - Launches a worker function in parallel across all processes in the current - environment. The environment must have the following variables set: - - WORLD_SIZE: The total number of processes. - - WORLD_LOCAL_SIZE: The number of processes on the current node. - - NODE_RANK: The rank of the current - - MASTER_ADDR: The address of the master process. - - MASTER_PORT: The port of the master process. - """ - assert not kwargs - world_size = int(os.environ["WORLD_SIZE"]) - world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) - node_rank = int(os.environ["NODE_RANK"]) - assert "MASTER_ADDR" in os.environ - assert "MASTER_PORT" in os.environ - spawn( - _worker_parallel_launch, - args=( - world_size, - world_local_size, - node_rank, - "env://", - worker, - ) - + args, - nprocs=world_local_size, - join=True, - ) - - -def torch_dispatch( - a: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - max_num_tokens: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - assert topk_ids.dim() == 2 - assert topk_ids.shape[0] == a.shape[0] - - num_tokens = a.shape[0] - topk = topk_ids.shape[1] - - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - if max_num_tokens is None: - max_num_tokens = tokens_per_expert.max() - - b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), - dtype=a.dtype, device=a.device) - - #print(f"b_a shape {b_a.shape}") - - token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) - - for token in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[token, j] - idx = token_counts[expert_id] - b_a[expert_id, idx:idx+1, :] = a[token, :] - token_counts[expert_id] = token_counts[expert_id] + 1 - - return b_a, tokens_per_expert - - -def torch_combine(b_out, topk_weight, topk_ids): - num_tokens, topk = topk_ids.shape - num_experts = b_out.shape[0] - K = b_out.shape[-1] - out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) - for token in range(num_tokens): - expert_ids = topk_ids[token] - for i in range(expert_ids.numel()): - expert_id = expert_ids[i] - idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - return out - - -def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): - num_experts = w1.shape[0] - b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) - assert b_a.dim() == 3 - num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = b_a.shape - assert num_experts == b_a.shape[0] and K == w2.shape[1] - out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) - for expert in range(num_experts): - num = tokens_per_expert[expert] - if num > 0: - torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) - out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) - - return torch_combine(out, topk_weight, topk_ids) - - -# TODO: same as torch_moe but with fused_topk factored out. -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_batched_experts( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, -): - current_platform.seed_everything(7) - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - if True: - triton_output = torch_batched_moe(a, - w1, - w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) - - if False: - torch.set_printoptions(profile="full") - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(triton_output) - - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - - -def rank_chunk(num, r, w): - rem = num % w - return (num // w) + (1 if r < rem else 0) - - -def chunk_by_rank(t, r, w): - chunk = rank_chunk(t.shape[0], r, w) - #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") - return t[(r * chunk):(r + 1)*chunk] - - -def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): - assert torch.cuda.current_device() == pgi.local_rank - - num_tokens, hidden_dim = a.shape - num_experts = w1.shape[0] - block_size = 128 - device = pgi.device - rank = pgi.rank - world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") - - ata = AllToAll( - max_num_tokens=max_num_tokens, - num_experts=num_experts, - experts_per_token=topk, - rank=rank, - world_size=pgi.world_size, - dp_size=dp_size, - hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), - ) - - dispatch_combine = PplxDispatchCombine( - ata, - max_num_tokens, - pgi.world_size, - dp_size, - rank, - a.dtype, - ) - - a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - - #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") - - b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( - a_chunk, - None, - None, - chunk_topk_ids, - num_experts, # store at PplxDispatchCombine creation? - None - ) - - #topk_weight, topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) - - torch.distributed.all_reduce(tokens_per_expert) - #max_num = tokens_per_expert.max() - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) - - #print(f"tpe {tokens_per_expert}") - #print(f"ent {expert_num_tokens}") - - #torch.set_printoptions(profile="full") - #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX) - #torch.distributed.broadcast(naive_b_a, src=rank) - - #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size) - - #print("b_a", b_a.shape, b_a) #[:, :naive_b_a.shape[1]]) - #print("naive_b_a", naive_b_a.shape, naive_b_a) - - torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) - #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0) - - b_a = b_a * 1.5 - - out = torch.full( - (rank_num_tokens * world_size, hidden_dim), - torch.nan, - dtype=a.dtype, - device=device, - ) - - dispatch_combine.combine( - out, - b_a, - chunk_topk_weight, - chunk_topk_ids, - ) - torch.cuda.synchronize() - - ata.destroy() - - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") - - #torch.distributed.all_reduce(out) - - #print(f"AR OUT {rank}: {out.shape} {out}") - - return out[:rank_num_tokens] - - -def _pplx_dispatch_combine( - pgi: ProcessGroupInfo, - dp_size: int, - m, n, k, e, - #a: torch.Tensor, - #w1: torch.Tensor, - #w2: torch.Tensor, - #score: torch.Tensor, - topk: int, - dtype: torch.dtype, -): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) - device = pgi.device - - a = torch.randn((m, k), device=device, dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 - score = torch.randn((m, e), device=device, dtype=dtype) - - #m, k = a.shape - #e, _, n = w2.shape - - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - #print(f"a {a.shape}") - a_rep = torch.repeat_interleave(a, topk, dim=0) - #print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}") - - torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) - - #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") - - pplx_output = torch_pplx_dispatch_combine(pgi, - dp_size, - a, - w1, - w2, - score, - topk) - - if False: - torch.set_printoptions(profile="full") - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(pplx_output) - - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - - nvshmem_finalize() - - -@pytest.mark.parametrize("m", [4, 32, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) -def test_pplx_dispatch_combine( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, - world_dp_size: Tuple[int, int], -): - current_platform.seed_everything(7) - world_size, dp_size = world_dp_size - - parallel_launch( - #world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype - world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype - ) - - -def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): - assert torch.cuda.current_device() == pgi.local_rank - - num_tokens, hidden_dim = a.shape - num_experts = w1.shape[0] - block_size = 128 - device = pgi.device - rank = pgi.rank - world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens - - ata = AllToAll( - max_num_tokens=max_num_tokens, - num_experts=num_experts, - experts_per_token=topk, - rank=rank, - world_size=pgi.world_size, - dp_size=dp_size, - hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), - ) - - w1 = w1.to(device) - w2 = w2.to(device) - - dispatch_combine = PplxDispatchCombine( - ata, - max_num_tokens, - pgi.world_size, - dp_size, - rank, - a.dtype, - ) - - experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) - - fused_experts = FusedMoEModularKernel( - dispatch_combine, - experts, - ) - - a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - - #print(f"chunk_topk_ids {rank} {chunk_topk_ids.shape} {chunk_topk_ids.view(-1)}") - - out = fused_experts( - a_chunk, - # Chunking weights like this only works for batched format - chunk_by_rank(w1, rank, world_size), - chunk_by_rank(w2, rank, world_size), - #w1, - #w2, - chunk_topk_weight, - chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? - ) - - torch.cuda.synchronize() - - ata.destroy() - - #print(f"OUT {rank}: {out.shape} {out}") - - return out[:rank_num_tokens] - - -def _pplx_moe( - pgi: ProcessGroupInfo, - dp_size: int, - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, - topk: int, - dtype: torch.dtype, -): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) - - m, k = a.shape - e, _, n = w2.shape - - torch.set_printoptions(profile="full") - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - #print(f"topk_ids {pgi.rank} {topk_ids.shape} {topk_ids.view(-1)}") - - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - pplx_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) - - if False: - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(pplx_output) - - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - - #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") - - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - - nvshmem_finalize() - - -# TODO: M == 1 doesn't work -@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024])# , 2048]) -@pytest.mark.parametrize("k", [128, 512]) # , 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) -def test_pplx_moe( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, - world_dp_size: Tuple[int, int], -): - current_platform.seed_everything(7) - world_size, dp_size = world_dp_size - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) - - parallel_launch( - world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype - #world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype - ) - diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 99c5bba8bc47..7f6ad5c261cc 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -93,16 +93,15 @@ def set_forward_context(attn_metadata: Any, from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) #TODO device? - max_tokens_across_dp = torch.max(num_tokens_tensor) #.to(device="cuda") + max_tokens_across_dp = torch.max( + num_tokens_tensor) #.to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_rank_num_tokens = torch.tensor( [num_tokens], dtype=torch.uint32, device=vllm_config.device_config.device) - dp_metadata = DPMetadata(max_tokens_across_dp, - num_tokens_tensor, - cu_tokens_across_dp_cpu, - dp_rank_num_tokens) + dp_metadata = DPMetadata(max_tokens_across_dp, num_tokens_tensor, + cu_tokens_across_dp_cpu, dp_rank_num_tokens) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index a694c53d9f36..266ba3bfa07a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -134,9 +134,7 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - self.activation(activation, - workspace2, - workspace1.view(-1, N)) + self.activation(activation, workspace2, workspace1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 907670cbb7b8..be700f7b2e99 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -7,24 +7,24 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, - try_get_optimal_moe_config, -) + get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache + @triton.jit -def batched_silu_and_mul_kernel(output, # [E, MAX_NUM_TOKENS, D] - input, # [E, MAX_NUM_TOKENS, D * 2] - expert_num_tokens, # [E] - stride_oe, - stride_om, - stride_ie, - stride_im, - compute_type: tl.constexpr, - D, - BLOCK_M: tl.constexpr, - BLOCK_D: tl.constexpr): +def batched_silu_and_mul_kernel( + output, # [E, MAX_NUM_TOKENS, D] + input, # [E, MAX_NUM_TOKENS, D * 2] + expert_num_tokens, # [E] + stride_oe, + stride_om, + stride_ie, + stride_im, + compute_type: tl.constexpr, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr): expert_id = tl.program_id(axis=0) e_num_tokens = tl.load(expert_num_tokens + expert_id) @@ -57,50 +57,53 @@ def batched_silu_and_mul_kernel(output, # [E, MAX_NUM_TOKENS, D] mask_D = offs_D < (D - (d * BLOCK_D)) mask_tile = mask_m & mask_D - x_tile = tl.load(cta_input_ptrs, mask=mask_tile, other=0.0).to(dtype=tl.float32) + x_tile = tl.load(cta_input_ptrs, mask=mask_tile, + other=0.0).to(dtype=tl.float32) y_tile = tl.load(cta_input_ptrs + D, mask=mask_tile, other=0.0) # silu and mul - out_tile = (x_tile * (1.0 / (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) + out_tile = (x_tile * (1.0 / + (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) out_tile = out_tile * y_tile tl.store(cta_output_ptrs, out_tile, mask=mask_tile) cta_input_ptrs = cta_input_ptrs + BLOCK_D cta_output_ptrs = cta_output_ptrs + BLOCK_D + @triton.jit def moe_mmk( - a_ptrs, - b_ptrs, - K, - expert_id, - a_scale_ptr, - b_scale_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ak, - stride_bk, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Offsets and masks - offs_m, - offs_n, - mask_m, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - compute_type: tl.constexpr, - use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr): + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): offs_k = tl.arange(0, BLOCK_K) @@ -131,12 +134,9 @@ def moe_mmk( # Load the next block of A and B, generate a mask by checking the # K dimension. a = tl.load(a_ptrs, - mask=mask_m[:, None] & - (offs_k[None, :] < K - k * BLOCK_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_K, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) # We accumulate along the K dimension. if use_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) @@ -177,41 +177,42 @@ def moe_mmk( @triton.jit -def expert_triton_kernel(a_ptr, #[max_tokens, K] - b_ptr, #[K, N] - c_ptr, #[max_tokens, N] - expert_id, - compute_type: tl.constexpr, - # Dimensions - M, - N, - K, - # Quantization data - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # strides - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n, - group_k, - # Quantization schemes - use_fp8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr, - # Kernel config - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr): +def expert_triton_kernel( + a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) % N @@ -221,7 +222,6 @@ def expert_triton_kernel(a_ptr, #[max_tokens, K] a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn - accumulator = moe_mmk( a_ptrs, b_ptrs, @@ -261,48 +261,50 @@ def expert_triton_kernel(a_ptr, #[max_tokens, K] c_mask = mask_m[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) + @triton.jit -def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] - b_ptr, # [E, K, N] - c_ptr, # [E, max_num_tokens, N] - expert_num_tokens, # [E] - compute_type: tl.constexpr, - # Dimensions - max_num_tokens, - K, - N, - # Quantization data - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ae, - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_ce, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n: tl.constexpr, - group_k: tl.constexpr, - # Quantization schemes - use_fp8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr, - # Kernel config - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr): +def batched_triton_kernel( + a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ae, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_ce, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): expert_id = tl.program_id(axis=0) e_num_tokens = tl.load(expert_num_tokens + expert_id) if e_num_tokens == 0: @@ -310,7 +312,7 @@ def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] return pid_mn = tl.program_id(axis=1) - num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) pid_m = pid_mn // num_pid_n pid_n = pid_mn % num_pid_n @@ -326,58 +328,61 @@ def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn - c_ptr = c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + cta_n_start * stride_cn - - expert_triton_kernel(a_ptr, - b_ptr, - c_ptr, - expert_id, - compute_type, - cta_m_size, # M - cta_n_size, # N - K, # K - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # Strides - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n, - group_k, - # Quantization schemes - use_fp8_w8a8, - use_int8_w8a16, - # Kernel config - BLOCK_M, - BLOCK_N, - BLOCK_K) - - -def invoke_moe_batched_triton_kernel(A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, K, N] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: torch.Tensor, - B_scale: torch.Tensor, - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - block_shape: Optional[list[int]] = None): + c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + + cta_n_start * stride_cn) + + expert_triton_kernel( + a_ptr, + b_ptr, + c_ptr, + expert_id, + compute_type, + cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # Strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M, + BLOCK_N, + BLOCK_K) + + +def invoke_moe_batched_triton_kernel( + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: torch.Tensor, + B_scale: torch.Tensor, + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): assert not use_int4_w4a16 max_num_tokens = A.size(1) @@ -389,53 +394,54 @@ def invoke_moe_batched_triton_kernel(A: torch.Tensor, # [E, max_tokens, K] BLOCK_K = config['BLOCK_SIZE_K'] assert max_num_tokens % BLOCK_M == 0 - grid = (expert_num_tokens.size(0), - triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.shape[1], BLOCK_N)) - - batched_triton_kernel[grid](A, - B, - C, - expert_num_tokens, - compute_type, - # Dimensions - max_num_tokens, - K, - N, - # Quantization data - A_scale, - B_scale, - B_zp, - # Strides - A.stride(0), - A.stride(1), - A.stride(2), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(0), - C.stride(1), - C.stride(2), - A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, - # Blockwise quantization data - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - # Quantization schemes - use_fp8_w8a8, - use_int8_w8a16, - # Kernel config - BLOCK_M = BLOCK_M, - BLOCK_N = BLOCK_N, - BLOCK_K = BLOCK_K) - - -def invoke_batched_silu_and_mul(output : torch.Tensor, #[E, MAX_TOKENS, D] - input: torch.Tensor, #[E, MAX_TOKENS, D * 2] - expert_num_tokens: torch.Tensor): + grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * + triton.cdiv(B.shape[1], BLOCK_N)) + + batched_triton_kernel[grid]( + A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K) + +def invoke_batched_silu_and_mul( + output: torch.Tensor, #[E, MAX_TOKENS, D] + input: torch.Tensor, #[E, MAX_TOKENS, D * 2] + expert_num_tokens: torch.Tensor): num_experts = output.size(0) max_num_tokens = output.size(1) @@ -444,24 +450,19 @@ def invoke_batched_silu_and_mul(output : torch.Tensor, #[E, MAX_TOKENS, D] BLOCK_D = 1024 BLOCK_M = 1 - compute_tl_dtype = {torch.float16 : tl.float16, - torch.float32 : tl.float32, - torch.bfloat16 : tl.bfloat16}[output.dtype] + compute_tl_dtype = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16 + }[output.dtype] #print(f"compute type {compute_tl_dtype}") grid = (num_experts, triton.cdiv(max_num_tokens, BLOCK_M)) - batched_silu_and_mul_kernel[grid](output, - input, - expert_num_tokens, - output.stride(0), - output.stride(1), - input.stride(0), - input.stride(1), - compute_tl_dtype, - D, - BLOCK_M, - BLOCK_D) + batched_silu_and_mul_kernel[grid](output, input, expert_num_tokens, + output.stride(0), output.stride(1), + input.stride(0), input.stride(1), + compute_tl_dtype, D, BLOCK_M, BLOCK_D) class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): @@ -621,8 +622,9 @@ def apply( if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation( - activation, tmp, hidden_states[expert, :num, :] - @ w1[expert].transpose(0, 1)) + activation, tmp, + hidden_states[expert, :num, :] @ w1[expert].transpose( + 0, 1)) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out @@ -685,15 +687,15 @@ def apply( ) -> torch.Tensor: num_tokens = topk_ids.size(0) - #print_debug = expert_map[0] != -1 and num_tokens < 50 and num_tokens != 1 and False # Check constraints. if self.use_int4_w4a16: assert hidden_states.shape[-1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[-1] == w1.shape[ - 2], f"Hidden size mismatch {hidden_states.shape[-1]} != {w1.shape[2]}" + assert hidden_states.shape[-1] == w1.shape[2], \ + (f"Hidden size mismatch {hidden_states.shape[-1]} " + f"!= {w1.shape[2]}") assert hidden_states.is_contiguous( ), "Hidden_states must be contiguous" @@ -764,7 +766,7 @@ def apply( input=intermediate_cache1, expert_num_tokens=expert_num_tokens) - qintermediate_cache2 = intermediate_cache2 + #qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale # TODO (varun) : support w8a8 assert not self.use_fp8_w8a8 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5bf49a8c2c28..bb7658519bc9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1207,28 +1207,29 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) -def fused_experts_impl(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: # Check constraints. if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ @@ -1631,22 +1632,32 @@ def apply( intermediate_cache3 = _resize_cache(workspace13, (num_tokens, top_k_num, K)) - if hidden_states.dim() == 2: #block_m is None: + if hidden_states.dim() == 2: #block_m is None: sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size( - topk_ids, - config['BLOCK_SIZE_M'], - global_num_experts, expert_map - )) + moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) else: max_num_tokens = hidden_states.shape[1] - sorted_token_ids = torch.arange(0, hidden_states.shape[0] * max_num_tokens, device=hidden_states.device, dtype=torch.int) + sorted_token_ids = torch.arange(0, + hidden_states.shape[0] * + max_num_tokens, + device=hidden_states.device, + dtype=torch.int) sorted_token_ids = sorted_token_ids.flatten() - expert_ids = torch.arange(0, global_num_experts, device=hidden_states.device, dtype=torch.int) - expert_ids = torch.repeat_interleave(expert_ids, max_num_tokens, dim=0) + expert_ids = torch.arange(0, + global_num_experts, + device=hidden_states.device, + dtype=torch.int) + expert_ids = torch.repeat_interleave(expert_ids, + max_num_tokens, + dim=0) print(f"EXPERT_IDS {expert_ids}") - #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32) - num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32) + #num_tokens_post_padded = torch.tensor([num_tokens], + # device=hidden_states.device, + # dtype=torch.int32) + num_tokens_post_padded = torch.zeros(1, + device=hidden_states.device, + dtype=torch.int32) num_tokens_post_padded.fill_(max_num_tokens) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) #print(f"P = {sorted_token_ids}, {hidden_states.shape}") @@ -1705,170 +1716,6 @@ def apply( return intermediate_cache3 -class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, - world_size: int, - rank: int): - super().__init__() - self.world_size = world_size - self.rank = rank - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - assert topk_ids.dim() == 2 - assert topk_ids.shape[0] == a1.shape[0] - - num_tokens = a1.shape[0] - topk = topk_ids.shape[1] - - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) - - b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), - dtype=a1.dtype, device=a1.device) - - #print(f"START DISPATCH {hex(id(self))}") - - for token in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[token, j] - idx = expert_counts[expert_id] - b_a1[expert_id, idx:idx+1, :] = a1[token, :] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - #print(f"END DISPATCH {hex(id(self))}: tokens_per_expert {(tokens_per_expert > 0).nonzero().view(-1)}") - - return b_a1, a1_scale, tokens_per_expert - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> None: - if False: - print(f"topk_ids {topk_ids.shape}") - print(f"fused_expert_output {fused_expert_output.shape}") - print(f"output {output.shape}") - print(f"counts {self.expert_counts.shape}") - - #print(f"START COMBINE {hex(id(self))}") - - num_tokens, topk = topk_ids.shape - num_experts, _, K = fused_expert_output.shape - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) - for token in range(num_tokens): - expert_ids = topk_ids[token] - for i in range(topk_ids.shape[1]): - expert_id = expert_ids[i] - if expert_id < num_experts: - idx = expert_counts[expert_id] - output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - #print(f"END COMBINE {hex(id(self))}") - - -def rank_chunk(num, r, w): - rem = num % w - return (num // w) + (1 if r < rem else 0) - - -class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): - - def __init__( - self, - rank: int = 0, - world_size: int = 1, - max_num_tokens: Optional[int] = None, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[List[int]] = None, - block_m: Optional[int] = None, - ): - super().__init__() - assert not use_fp8_w8a8 - assert not use_int4_w4a16 - assert not use_int8_w8a16 - assert block_shape is None - assert block_m is None - self.max_num_tokens = max_num_tokens - self.rank = rank - self.world_size = world_size - - def workspace_shapes( - self, - a: torch.Tensor, - M: int, - N: int, - K: int, - topk: int, - num_experts: int, - ) -> Tuple[int, int, torch.dtype]: - #assert self.max_num_tokens >= a.shape[1] - max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack - workspace2 = max_num_tokens * N - return (workspace13, workspace2, a.dtype) - - def apply( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: - #print("START EXPERTS") - assert hidden_states.dim() == 3 - assert expert_num_tokens is not None - num_tokens, topk = topk_ids.shape - _, tmp_max_num_tokens, K = hidden_states.shape - max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens - #print(f"global_num_experts = {global_num_experts}") - num_experts = global_num_experts - out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) - num_local_experts = expert_num_tokens.numel() - #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}") - - # TODO: don't need world_size or rank if expert_base always == 0 - #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" - #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank - expert_base = 0 - - for expert in range(num_local_experts): # num_experts - num = expert_num_tokens[expert] - assert num <= max_num_tokens, f"{num}, {max_num_tokens}" - #print(f"{type(num)}, {num}, {max_num_tokens}") - if num > 0: - tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) - self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1)) - out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) - - return out - - def modular_triton_fused_moe( use_fp8_w8a8: bool, use_int8_w8a8: bool, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d5f2b165e8b4..11a143e5f28b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,9 +29,10 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, fused_experts from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts - from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine + from .fused_moe import TritonExperts, fused_experts + from .modular_kernel import (FusedMoEModularKernel, + FusedMoEQuantizeDispatchCombine) from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore @@ -80,7 +81,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + def set_dispatch_combine( + self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: return False @abstractmethod @@ -241,29 +243,31 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input) # Maybe extra args - def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + def set_dispatch_combine( + self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: assert self.fused_experts == fused_experts - block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) + #block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info(f"BatchedTritonExperts {self.moe}") + if isinstance(dispatch_combine, + (BatchedDispatchCombine, PplxDispatchCombine)): + logger.info("BatchedTritonExperts %s", self.moe) experts = BatchedTritonExperts( - use_fp8_w8a8 = False, - use_int8_w8a8 = False, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = None, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, ) else: - logger.info(f"TritonExperts {self.moe}") + logger.info("TritonExperts %s", self.moe) experts = TritonExperts( - use_fp8_w8a8 = False, - use_int8_w8a8 = False, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = None, - per_channel_quant = False, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, ) self.fused_experts = FusedMoEModularKernel( @@ -598,8 +602,8 @@ def __init__( dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - in_dtype = params_dtype, # this is probably not right, where to get? - out_dtype = params_dtype, # ditto. + in_dtype=params_dtype, # this is probably not right, where to get? + out_dtype=params_dtype, # ditto. ) # Note: get_quant_method will look at the layer's local_num_experts @@ -618,46 +622,41 @@ def __init__( # TODO: move to method? if self.dp_size > 1: logger.info("using pplx dispatch") - max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank all_to_all = get_all_to_all( max_num_tokens=max_num_tokens, num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk + experts_per_token=moe.experts_per_token, # topk rank=rank, world_size=world_size, dp_size=dp_size, hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=( - 0 - if moe.in_dtype.itemsize != 1 - else ( - (moe.hidden_dim + moe.block_size - 1) - // moe.block_size - * torch.float32.itemsize - ) - ) - ) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize))) dispatch_combine = PplxDispatchCombine( all_to_all, max_num_tokens, world_size, dp_size, - rank, # just for debugging + rank, # just for debugging moe.in_dtype, ) elif True: logger.info("using standard dispatch") dispatch_combine = StandardDispatchCombine( moe.in_dtype, - quant_config.weight_block_size if quant_config is not None else None, + quant_config.weight_block_size + if quant_config is not None else None, ) else: logger.info("using batched dispatch") @@ -668,7 +667,8 @@ def __init__( success = self.quant_method.set_dispatch_combine(dispatch_combine) if not success: - logger.warning("DP+EP not supported for %s.", type(self.quant_method)) + logger.warning("DP+EP not supported for %s.", + type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -1018,12 +1018,14 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, ctx = get_forward_context() max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp - cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu + #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp #In this function we define two ranges: - # 1. chunk_range - The current iteration of the loops's range over the DP world tokens - # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. + # 1. chunk_range - The current iteration of the loops's range over the + # DP world tokens + # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP + # rank owns. moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size @@ -1071,8 +1073,11 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, # TODO: needed for non-pplx? if False and self.dp_size > 1: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ - self.dp_rank - 1] + if self.dp_rank == 0: + start = 0 + else: + start = cu_tokens_across_dp_this_iter[self.dp_rank - 1] + end = cu_tokens_across_dp_this_iter[self.dp_rank] all_hidden_states = get_dp_group().all_reduce( @@ -1080,7 +1085,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states = all_hidden_states[start:end, :] # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + if False and self.reduce_results and (self.tp_size > 1 + or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1155,7 +1161,8 @@ def forward_impl(self, hidden_states: torch.Tensor, final_hidden_states = all_hidden_states[start:end, :] # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + if False and self.reduce_results and (self.tp_size > 1 + or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index d550c8b040c9..eec5a7406d90 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -67,8 +67,8 @@ def _moe_problem_size( M = a1.shape[0] else: assert a1.dim() == 3 - assert E == a1.shape[0] - M = a1.shape[1] # This is max_num_tokens + assert a1.shape[0] == E + M = a1.shape[1] # This is max_num_tokens assert topk_ids.dim() == 2 topk = topk_ids.shape[1] diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 576c454ec31d..420a81f3f5c8 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -9,6 +9,11 @@ moe_kernel_quantize_input) +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. @@ -97,7 +102,7 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - num_tokens = a1.shape[0] # M + num_tokens = a1.shape[0] # M bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) # TODO: optimize this? @@ -123,8 +128,9 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - num_tokens = output.shape[0] # M - bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + num_tokens = output.shape[0] # M + bound_m = torch.tensor([num_tokens], + dtype=torch.uint32, device=fused_expert_output.device) assert output.shape[0] <= self.max_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index e85f35141602..0d0212b7591c 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,36 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 -import importlib.util from typing import List, Optional, Tuple import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, - _valid_deep_gemm_shape, - _valid_deep_gemm, -) + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExpert + class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - block_shape: Optional[List[int]] = None, - block_m: Optional[int] = None, - allow_deep_gemm: bool = False - ): + def __init__(self, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + allow_deep_gemm: bool = False): super().__init__() - self.triton_expert = TritonExpert( - use_fp8_w8a8, - use_int4_w4a16, - use_int8_w8a16, - block_shape, - block_m - ) + self.triton_expert = TritonExpert(use_fp8_w8a8, use_int4_w4a16, + use_int8_w8a16, block_shape, block_m) self.deep_gemm_expert = DeepGemmExperts() self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 @@ -48,9 +38,11 @@ def workspace_shapes( # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): - return self.deep_gemm_expert.workspace_shapes(a, M, N, K, topk, num_experts) + return self.deep_gemm_expert.workspace_shapes( + a, M, N, K, topk, num_experts) else: - return self.triton_expert.workspace_shapes(a, M, N, K, topk, num_experts) + return self.triton_expert.workspace_shapes(a, M, N, K, topk, + num_experts) def apply( self, @@ -73,7 +65,7 @@ def apply( ) -> torch.Tensor: N = w1.shape[1] if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): return self.deep_gemm_expert( hidden_states, w1, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 70feeb1167f5..6f5505c1273c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -10,8 +10,8 @@ from torch.nn.parameter import Parameter import vllm.envs as envs -from vllm import _custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, @@ -446,7 +446,6 @@ def __init__(self, quant_config: Fp8Config): from vllm.model_executor.layers.fused_moe import fused_experts self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.allow_deep_gemm = allow_deep_gemm # Check for DeepGemm support. self.allow_deep_gemm = False @@ -778,18 +777,21 @@ def process_weights_after_loading(self, layer: Module) -> None: return # Maybe extra args - def set_dispatch_combine(self, dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: - from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import TritonOrDeepGemmExperts + def set_dispatch_combine( + self, + dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) #block_m = MOE_DP_CHUNK_SIZE * (moe.ep_size // moe.dp_size) #print(f"block_m = {block_m}") experts = TritonOrDeepGemmExperts( - use_fp8_w8a8 = True, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = self.quant_config.weight_block_size, - block_m = None, # TODO + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=self.quant_config.weight_block_size, + block_m=None, # TODO allow_deep_gemm=self.allow_deep_gemm, ) @@ -833,8 +835,8 @@ def apply( return self.fused_experts( hidden_states=x, - layer.w13_weight, - layer.w2_weight, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, From 448658a801099932d4018ae1774129f85e290637 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 23:19:29 +0000 Subject: [PATCH 160/171] more lint stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 5 +++++ .../layers/fused_moe/triton_deep_gemm_moe.py | 10 +++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 11a143e5f28b..dc6e0b50d1b1 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -32,6 +32,7 @@ from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts from .modular_kernel import (FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, FusedMoEQuantizeDispatchCombine) from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -249,6 +250,8 @@ def set_dispatch_combine( #block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) + experts: FusedMoEPermuteExpertsUnpermute = None + if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): logger.info("BatchedTritonExperts %s", self.moe) @@ -619,6 +622,8 @@ def __init__( assert quant_method is not None self.quant_method = quant_method + dispatch_combine: FusedMoEQuantizeDispatchCombine = None + # TODO: move to method? if self.dp_size > 1: logger.info("using pplx dispatch") diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 0d0212b7591c..d24ae4768a67 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -6,21 +6,25 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) -from vllm.model_executor.layers.fused_moe.fused_moe import TritonExpert +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None, block_m: Optional[int] = None, allow_deep_gemm: bool = False): super().__init__() - self.triton_expert = TritonExpert(use_fp8_w8a8, use_int4_w4a16, - use_int8_w8a16, block_shape, block_m) + self.triton_expert = TritonExperts(use_fp8_w8a8, use_int8_w8a8, + use_int4_w4a16, use_int8_w8a16, + per_channel_quant, block_shape, + block_m) self.deep_gemm_expert = DeepGemmExperts() self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 From c7ddca42dcc3fb5ef854aea808cfd237b2daba5e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 02:26:57 +0000 Subject: [PATCH 161/171] add guards for pplx import Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 21 +++++++++++++++---- vllm/distributed/parallel_state.py | 12 +++++++---- vllm/model_executor/layers/fused_moe/layer.py | 12 ++++++++--- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index aeedadea3852..ff45c0798cf1 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -10,10 +10,16 @@ import pytest import torch -from pplx_kernels import AllToAll -from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) + +try: + from pplx_kernels import AllToAll + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) + has_pplx = False +except ImportError as ex: + has_pplx = False + from torch.multiprocessing import ( spawn) # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec @@ -45,6 +51,11 @@ reason="Requires multi-node environment", ) +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + @dataclasses.dataclass class ProcessGroupInfo: @@ -420,6 +431,7 @@ def _pplx_dispatch_combine( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") def test_pplx_dispatch_combine( m: int, n: int, @@ -543,6 +555,7 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") def test_pplx_moe( m: int, n: int, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cf715681c878..c2bd6dba5375 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,6 +23,7 @@ """ import contextlib import gc +import importlib import pickle import weakref from collections import namedtuple @@ -34,9 +35,6 @@ import torch import torch.distributed -from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) from torch.distributed import Backend, ProcessGroup import vllm.envs as envs @@ -920,7 +918,12 @@ def init_distributed_environment( @run_once def pplx_init(rank, world_size): - if world_size > 1: + has_pplx = importlib.util.find_spec("pplx_kernels") is not None + + if has_pplx and world_size > 1: + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init) try: global PPLX_DID_INIT logger.debug(f"PPLX_INIT {rank} {world_size}") @@ -940,6 +943,7 @@ def pplx_init(rank, world_size): def pplx_finalize(): global PPLX_DID_INIT if PPLX_DID_INIT: + from pplx_kernels.nvshmem import nvshmem_finalize nvshmem_finalize() diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index dc6e0b50d1b1..010568c59065 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib import threading import weakref from abc import abstractmethod @@ -7,7 +8,6 @@ from enum import Enum from typing import Callable, List, Optional, Tuple -import pplx_kernels as pplx # TODO: guard this import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter @@ -27,6 +27,8 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op +has_pplx = importlib.util.find_spec("pplx_kernels") is not None + if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts @@ -34,7 +36,8 @@ from .modular_kernel import (FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, FusedMoEQuantizeDispatchCombine) - from .pplx_dispatch_combine import PplxDispatchCombine + if has_pplx: + from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore if current_platform.is_tpu(): @@ -115,6 +118,9 @@ def __init__(self): self._lock = threading.RLock() # Reentrant lock for thread safety def get_or_create(self, **kwargs): + assert has_pplx + import pplx_kernels as pplx + # Create a hashable key from the kwargs key = tuple(sorted((k, v) for k, v in kwargs.items())) @@ -625,7 +631,7 @@ def __init__( dispatch_combine: FusedMoEQuantizeDispatchCombine = None # TODO: move to method? - if self.dp_size > 1: + if self.dp_size > 1 and has_pplx: logger.info("using pplx dispatch") max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size From 22b988a95a4eefca55cb8aabcb2a6dc44ae6d256 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 30 Apr 2025 10:55:48 -0400 Subject: [PATCH 162/171] fix forward_chunked Signed-off-by: Varun Sundar Rabindranath Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 62 +++++-------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 010568c59065..22ba68de5483 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1023,40 +1023,16 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): - ctx = get_forward_context() - - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp - #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu - num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp - - #In this function we define two ranges: - # 1. chunk_range - The current iteration of the loops's range over the - # DP world tokens - # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP - # rank owns. - - moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size - - num_tokens_remaining_across_dp = num_tokens_across_dp - chunk_start = 0 - chunk_end = min(moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - assert full_hidden_states.shape[0] == full_router_logits.shape[0] - - for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): + def process_chunk(chunk_start, chunk_end, skip_result_store = False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - cu_tokens_across_dp_this_iter = torch.cumsum( - num_tokens_remaining_across_dp.clamp( - max=moe_dp_chunk_size_per_rank), - dim=0) - # TODO: still may be needed for non-pplx, put into dispatcher class. if False: hidden_states = self.naive_multicast( @@ -1102,30 +1078,22 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states) - - # Update bounds - num_tokens_remaining_across_dp = torch.clamp( - num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, - min=0) + if not skip_result_store: + full_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states) - # HACK FIX - if num_tokens_remaining_across_dp.sum() == 0: - break + max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size - def update_chunk_bound(x: int): - return min(x + moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) + num_tokens = full_hidden_states.size(0) + for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp) + # clamp start and end + chunk_start = min(chunk_start, num_tokens - 1) + chunk_end = min(chunk_end, num_tokens) - #chunk_start = update_chunk_bound(chunk_start) - #chunk_end = update_chunk_bound(chunk_end) - if chunk_end == full_hidden_states.shape[0]: - # simply redo computation - pass - else: - chunk_start = update_chunk_bound(chunk_start) - chunk_end = update_chunk_bound(chunk_end) + process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens) return full_final_hidden_states From c09cefd0f8b068ce3dcd41c8c07342bcc408aab9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 17:04:54 +0000 Subject: [PATCH 163/171] fix more lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- vllm/distributed/parallel_state.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index ff45c0798cf1..9557758f0ed1 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -17,7 +17,7 @@ nvshmem_finalize, nvshmem_get_unique_id, nvshmem_init) has_pplx = False -except ImportError as ex: +except ImportError: has_pplx = False from torch.multiprocessing import ( diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c2bd6dba5375..6a3725b88c8f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -922,16 +922,15 @@ def pplx_init(rank, world_size): if has_pplx and world_size > 1: from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, - nvshmem_init) + nvshmem_get_unique_id, nvshmem_init) try: global PPLX_DID_INIT - logger.debug(f"PPLX_INIT {rank} {world_size}") + logger.debug("PPLX_INIT %s %d", rank, world_size) uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() uid_gpu = uid.cuda() get_world_group().broadcast(uid_gpu, src=0) - logger.debug(f"PPLX_INIT UID={uid_gpu}") + logger.debug("PPLX_INIT UID = %s", uid_gpu) uid = uid_gpu.to(device='cpu') nvshmem_init(uid, rank, world_size) PPLX_DID_INIT = True @@ -944,6 +943,7 @@ def pplx_finalize(): global PPLX_DID_INIT if PPLX_DID_INIT: from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX finalize") nvshmem_finalize() From 938c516f499e073a4dcee22790e23fb21d6f10e6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 21:27:29 +0000 Subject: [PATCH 164/171] cleanups Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 48 +++++----- vllm/forward_context.py | 2 +- .../layers/fused_moe/deep_gemm_moe.py | 2 + vllm/model_executor/layers/fused_moe/layer.py | 94 ++++++++++--------- .../layers/fused_moe/pplx_dispatch_combine.py | 14 ++- .../layers/fused_moe/triton_deep_gemm_moe.py | 10 +- .../model_executor/layers/quantization/fp8.py | 7 -- 7 files changed, 90 insertions(+), 87 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 9557758f0ed1..6dd028894b34 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -16,7 +16,7 @@ from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, nvshmem_finalize, nvshmem_get_unique_id, nvshmem_init) - has_pplx = False + has_pplx = True except ImportError: has_pplx = False @@ -46,11 +46,6 @@ P = ParamSpec("P") -require_multi_node = pytest.mark.skipif( - "MASTER_ADDR" not in os.environ, - reason="Requires multi-node environment", -) - requires_pplx = pytest.mark.skipif( not has_pplx, reason="Requires PPLX kernels", @@ -180,6 +175,9 @@ def torch_dispatch( tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + + assert tokens_per_expert.numel() == num_experts + if max_num_tokens is None: max_num_tokens = int(tokens_per_expert.max().item()) @@ -259,7 +257,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -309,7 +307,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): rank = pgi.rank world_size = pgi.world_size rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens + max_num_tokens = max(num_tokens, 1) ata = AllToAll.internode( max_num_tokens=max_num_tokens, @@ -350,22 +348,23 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): False, ) - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, - num_experts) + if False: + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, + num_experts) - torch.distributed.all_reduce(tokens_per_expert) - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, - world_size).to(dtype=torch.int32) + torch.distributed.all_reduce(tokens_per_expert) + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, + world_size).to(dtype=torch.int32) - torch.testing.assert_close(tokens_per_expert, - expert_num_tokens, - atol=0, - rtol=0) + torch.testing.assert_close(tokens_per_expert, + expert_num_tokens, + atol=0, + rtol=0) b_a = b_a * 1.5 out = torch.full( - (rank_num_tokens * world_size, hidden_dim), + (rank_num_tokens, hidden_dim), torch.nan, dtype=a.dtype, device=device, @@ -424,14 +423,15 @@ def _pplx_dispatch_combine( nvshmem_finalize() +# TODO: M < world_size doesn't appear to be supported by pplx? @pytest.mark.parametrize("m", [4, 32, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? +@pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) -@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") +@requires_pplx def test_pplx_dispatch_combine( m: int, n: int, @@ -502,11 +502,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): # Chunking weights like this only works for batched format chunk_by_rank(w1, rank, world_size), chunk_by_rank(w2, rank, world_size), - #w1, - #w2, chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? + global_num_experts=num_experts ) torch.cuda.synchronize() @@ -547,7 +545,7 @@ def _pplx_moe( nvshmem_finalize() -# TODO: M == 1 doesn't work +# TODO: M < world_size doesn't appear to be supported by pplx? @pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @@ -555,7 +553,7 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) -@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") +@requires_pplx def test_pplx_moe( m: int, n: int, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7f6ad5c261cc..8e03541db873 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -92,7 +92,7 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) - #TODO device? + #TODO device? (tms) max_tokens_across_dp = torch.max( num_tokens_tensor) #.to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 266ba3bfa07a..4a0fb374bd41 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util from typing import Optional, Tuple @@ -19,6 +20,7 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None +@functools.cache def deep_gemm_block_shape() -> list[int]: # Lazy import to avoid CUDA initialization problems. import deep_gemm as dg diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 22ba68de5483..91f7d50492de 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -63,8 +63,7 @@ class MoEConfig: ep_size: int ep_rank: int - in_dtype: torch.dtype - out_dtype: torch.dtype + in_dtype: torch.dtype # The activation type. # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 @@ -142,7 +141,6 @@ def get_all_to_all(**kwargs): return _all_to_all_cache.get_or_create(**kwargs) -#TODO: Every change in this class is a broken hack!! @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" @@ -249,18 +247,15 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - # Maybe extra args def set_dispatch_combine( self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: assert self.fused_experts == fused_experts - #block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - experts: FusedMoEPermuteExpertsUnpermute = None if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info("BatchedTritonExperts %s", self.moe) + logger.debug("BatchedTritonExperts %s", self.moe) experts = BatchedTritonExperts( use_fp8_w8a8=False, use_int8_w8a8=False, @@ -269,7 +264,7 @@ def set_dispatch_combine( block_shape=None, ) else: - logger.info("TritonExperts %s", self.moe) + logger.debug("TritonExperts %s", self.moe) experts = TritonExperts( use_fp8_w8a8=False, use_int8_w8a8=False, @@ -611,8 +606,7 @@ def __init__( dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - in_dtype=params_dtype, # this is probably not right, where to get? - out_dtype=params_dtype, # ditto. + in_dtype=params_dtype, # TODO: is this right? ) # Note: get_quant_method will look at the layer's local_num_experts @@ -628,12 +622,42 @@ def __init__( assert quant_method is not None self.quant_method = quant_method - dispatch_combine: FusedMoEQuantizeDispatchCombine = None + dispatch_combine = self._construct_dispatch_combine( + moe, quant_config) + + success = self.quant_method.set_dispatch_combine(dispatch_combine) + + if not success: + logger.warning("DP+EP not supported for %s.", + type(self.quant_method)) + + self.apply_router_weight_on_input = apply_router_weight_on_input + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + self.quant_method.create_weights(layer=self, **moe_quant_params) - # TODO: move to method? + # TODO: return Optional? + def _construct_dispatch_combine( + self, + moe: MoEConfig, + quant_config: Optional[QuantizationConfig], + ) -> FusedMoEQuantizeDispatchCombine: if self.dp_size > 1 and has_pplx: - logger.info("using pplx dispatch") - max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + logger.debug("using pplx dispatch") + max_num_tokens = MOE_DP_CHUNK_SIZE world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank @@ -654,51 +678,28 @@ def __init__( (moe.hidden_dim + moe.block_size - 1) // moe.block_size * torch.float32.itemsize))) - dispatch_combine = PplxDispatchCombine( + return PplxDispatchCombine( all_to_all, max_num_tokens, world_size, dp_size, - rank, # just for debugging + rank, moe.in_dtype, ) elif True: - logger.info("using standard dispatch") - dispatch_combine = StandardDispatchCombine( + logger.debug("using standard dispatch") + return StandardDispatchCombine( moe.in_dtype, quant_config.weight_block_size if quant_config is not None else None, ) else: - logger.info("using batched dispatch") - dispatch_combine = BatchedDispatchCombine( + logger.debug("using batched dispatch") + return BatchedDispatchCombine( moe.ep_size, moe.ep_rank, ) - success = self.quant_method.set_dispatch_combine(dispatch_combine) - if not success: - logger.warning("DP+EP not supported for %s.", - type(self.quant_method)) - - self.apply_router_weight_on_input = apply_router_weight_on_input - moe_quant_params = { - "num_experts": self.local_num_experts, - "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, - "params_dtype": params_dtype, - "weight_loader": self.weight_loader, - } - # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod")): - moe_quant_params["intermediate_size_full"] = intermediate_size - - self.quant_method.create_weights(layer=self, **moe_quant_params) - def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -1015,9 +1016,14 @@ def naive_multicast(self, x: torch.Tensor, return buffer + # TODO: will this be cudagraph-able? (probably not) + # This should not be necessary. + def invalid_pplx(self, hidden_states: torch.Tensor) -> bool: + return has_pplx and hidden_states.shape[0] < self.dp_size + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call: + if self.use_direct_call or self.invalid_pplx(hidden_states): return self.forward_impl(hidden_states, router_logits) else: return torch.ops.vllm.moe_forward(hidden_states, router_logits, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 420a81f3f5c8..4c00edd0b3d8 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -28,6 +28,7 @@ def __init__(self, quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() + assert max_num_tokens > 0 self.a2a = a2a self.block_shape = block_shape self.max_num_tokens = max_num_tokens @@ -47,13 +48,15 @@ def dispatch( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - # Is this always going to be a1.device? - device = a1.device + num_tokens = a1.shape[0] # M hidden_dim = a1.shape[-1] # K - # ?? + assert rank_topk_ids.shape[0] == num_tokens # assert expert_map is None, "NYI" + # Is this always going to be a1.device? + device = a1.device + if apply_router_weight_on_input: topk = rank_topk_ids.shape[1] # TODO: this only works for topK=1, will need to update for topK>1 @@ -102,7 +105,6 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - num_tokens = a1.shape[0] # M bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) # TODO: optimize this? @@ -133,7 +135,9 @@ def combine( dtype=torch.uint32, device=fused_expert_output.device) - assert output.shape[0] <= self.max_num_tokens + assert topk_ids.shape[0] <= num_tokens + assert output.shape[0] <= self.max_num_tokens, \ + f"{output.shape[0]} <= {self.max_num_tokens}" assert output.shape[1] == fused_expert_output.shape[-1] # Set weights to 1 if we did them in dispatch. This is hacky. diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index d24ae4768a67..5ddb0e668423 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -12,11 +12,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, block_shape: Optional[List[int]] = None, block_m: Optional[int] = None, allow_deep_gemm: bool = False): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 6f5505c1273c..acf43284a932 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -776,22 +776,15 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) return - # Maybe extra args def set_dispatch_combine( self, dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) - #block_m = MOE_DP_CHUNK_SIZE * (moe.ep_size // moe.dp_size) - #print(f"block_m = {block_m}") - experts = TritonOrDeepGemmExperts( use_fp8_w8a8=True, - use_int8_w8a16=False, - use_int4_w4a16=False, block_shape=self.quant_config.weight_block_size, - block_m=None, # TODO allow_deep_gemm=self.allow_deep_gemm, ) From c0fc027a515ec723e4dc633eeb1ccdcb534e0008 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 21:32:39 +0000 Subject: [PATCH 165/171] cleanups + lint, layer.py wip Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 3 +-- vllm/model_executor/layers/fused_moe/layer.py | 21 +++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 6dd028894b34..111a5a30176d 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -504,8 +504,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): chunk_by_rank(w2, rank, world_size), chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts - ) + global_num_experts=num_experts) torch.cuda.synchronize() diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 91f7d50492de..8022389b73c8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -622,8 +622,7 @@ def __init__( assert quant_method is not None self.quant_method = quant_method - dispatch_combine = self._construct_dispatch_combine( - moe, quant_config) + dispatch_combine = self._construct_dispatch_combine(moe, quant_config) success = self.quant_method.set_dispatch_combine(dispatch_combine) @@ -1029,13 +1028,12 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): full_final_hidden_states = torch.empty_like(full_hidden_states) - def process_chunk(chunk_start, chunk_end, skip_result_store = False): + def process_chunk(chunk_start, chunk_end, skip_result_store=False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] @@ -1088,18 +1086,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store = False): full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp + max_tokens_across_dp = get_forward_context( + ).dp_metadata.max_tokens_across_dp moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size num_tokens = full_hidden_states.size(0) - for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): - chunk_start = chunk_start_ - chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp) + for chunk_start_ in range(0, max_tokens_across_dp, + moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, + max_tokens_across_dp) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) - process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens) + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) return full_final_hidden_states From f74ab61d6b0c496c402b17c1ed18c5f1041d66de Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 21:43:57 +0000 Subject: [PATCH 166/171] fix parallel_state lint Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6a3725b88c8f..2cedaa060189 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,7 +23,7 @@ """ import contextlib import gc -import importlib +import importlib.util import pickle import weakref from collections import namedtuple @@ -925,7 +925,7 @@ def pplx_init(rank, world_size): nvshmem_get_unique_id, nvshmem_init) try: global PPLX_DID_INIT - logger.debug("PPLX_INIT %s %d", rank, world_size) + logger.info("PPLX_INIT rank=%d world=%d", rank, world_size) uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() uid_gpu = uid.cuda() @@ -943,7 +943,7 @@ def pplx_finalize(): global PPLX_DID_INIT if PPLX_DID_INIT: from pplx_kernels.nvshmem import nvshmem_finalize - logger.debug("PPLX finalize") + logger.info("PPLX finalize") nvshmem_finalize() From 3e8a0e36695085057663eb519217aec72e0b0856 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 02:48:00 +0000 Subject: [PATCH 167/171] fix M=1 pplx test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 106 +++++++++--------- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 2 files changed, 56 insertions(+), 52 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 111a5a30176d..b6c15b1a2bba 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -297,18 +297,24 @@ def chunk_by_rank(t, r, w): return t[(r * chunk):(r + 1) * chunk] -def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): +ata = None + +def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): assert torch.cuda.current_device() == pgi.local_rank + topk = topk_ids.shape[1] + + #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + num_tokens, hidden_dim = a.shape - num_experts = w1.shape[0] block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = max(num_tokens, 1) + max_num_tokens = rank_chunk(num_tokens, 0, world_size) + print(f"MAX_NUM_TOKENS = {max_num_tokens}") + global ata ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, @@ -333,9 +339,11 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, - False) + num_tokens = a_chunk.shape[0] + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + print(f"{rank}: shapes {a_chunk.shape}, {chunk_topk_weight.shape}, {chunk_topk_ids.shape}, E={num_experts}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -343,11 +351,13 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): None, chunk_topk_weight, chunk_topk_ids, - num_experts, # store at PplxDispatchCombine creation? + num_experts, None, False, ) + #torch.cuda.synchronize() + if False: naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) @@ -364,7 +374,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): b_a = b_a * 1.5 out = torch.full( - (rank_num_tokens, hidden_dim), + (max_num_tokens, hidden_dim), torch.nan, dtype=a.dtype, device=device, @@ -377,22 +387,21 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): chunk_topk_ids, False, ) - torch.cuda.synchronize() - ata.destroy() + #torch.cuda.synchronize() + + #ata.destroy() - return out[:rank_num_tokens] + return out[:num_tokens] def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - m, - n, - k, - e, - topk: int, - dtype: torch.dtype, + a, + topk_weight, + topk_ids, + num_experts, ): uid = nvshmem_get_unique_id( ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() @@ -400,37 +409,34 @@ def _pplx_dispatch_combine( nvshmem_init(uid, pgi.rank, pgi.world_size) device = pgi.device - a = torch.randn((m, k), device=device, dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 - score = torch.randn((m, e), device=device, dtype=dtype) - - topk_weight, topk_ids = fused_topk(a, score, topk, False) + k = a.shape[1] + topk = topk_ids.shape[1] - a_rep = torch.repeat_interleave(a, topk, dim=0) + a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) torch_output = (a_rep.view(-1, topk, k) * 1.5 * - topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(a.dtype) - pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, score, - topk) + pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + print(f"{pgi.rank}: out shapes {pplx_output.shape}, {torch_output.shape}") + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() # TODO: M < world_size doesn't appear to be supported by pplx? -@pytest.mark.parametrize("m", [4, 32, 64, 222]) +@pytest.mark.parametrize("m", [1, 4, 32, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #[[4, 2]]) @requires_pplx def test_pplx_dispatch_combine( m: int, @@ -443,22 +449,27 @@ def test_pplx_dispatch_combine( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size + device = "cuda" + + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + topk_weight, topk_ids = fused_topk(a, score, topk, False) - parallel_launch(world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, - topk, dtype) + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, topk_weight, topk_ids, e) -def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): +def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): assert torch.cuda.current_device() == pgi.local_rank - num_tokens, hidden_dim = a.shape + hidden_dim = a.shape[1] num_experts = w1.shape[0] block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens + topk = topk_ids.shape[1] + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) ata = AllToAll.internode( max_num_tokens=max_num_tokens, @@ -474,9 +485,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): torch.float32.itemsize)), ) - w1 = w1.to(device) - w2 = w2.to(device) - dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, @@ -493,15 +501,14 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, - False) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) out = fused_experts( a_chunk, # Chunking weights like this only works for batched format - chunk_by_rank(w1, rank, world_size), - chunk_by_rank(w2, rank, world_size), + chunk_by_rank(w1, rank, world_size).to(device), + chunk_by_rank(w2, rank, world_size).to(device), chunk_topk_weight, chunk_topk_ids, global_num_experts=num_experts) @@ -510,7 +517,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - return out[:rank_num_tokens] + return out def _pplx_moe( @@ -521,7 +528,6 @@ def _pplx_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, - dtype: torch.dtype, ): uid = nvshmem_get_unique_id( ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() @@ -534,7 +540,7 @@ def _pplx_moe( with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk) + pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -544,8 +550,7 @@ def _pplx_moe( nvshmem_finalize() -# TODO: M < world_size doesn't appear to be supported by pplx? -@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) +@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -569,5 +574,4 @@ def test_pplx_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, - dtype) + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8022389b73c8..f5c0a452a2bf 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -251,7 +251,7 @@ def set_dispatch_combine( self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: assert self.fused_experts == fused_experts - experts: FusedMoEPermuteExpertsUnpermute = None + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): From 886045e54d1d3a5aedef42ea6c4e3916fe3f7e1d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 04:04:24 +0000 Subject: [PATCH 168/171] fix M=1 pplx test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 68 +++++++++--------------------- 1 file changed, 19 insertions(+), 49 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index b6c15b1a2bba..26021d201937 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,7 +28,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedExperts) + BatchedExperts, BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -293,34 +293,26 @@ def rank_chunk(num, r, w): def chunk_by_rank(t, r, w): chunk = rank_chunk(t.shape[0], r, w) - #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") return t[(r * chunk):(r + 1) * chunk] -ata = None - def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): assert torch.cuda.current_device() == pgi.local_rank topk = topk_ids.shape[1] - - #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - - num_tokens, hidden_dim = a.shape + num_tokens, hidden_dim = a.shape[1] block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size max_num_tokens = rank_chunk(num_tokens, 0, world_size) - print(f"MAX_NUM_TOKENS = {max_num_tokens}") - global ata ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, rank=rank, - world_size=pgi.world_size, + world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, @@ -332,19 +324,15 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, - pgi.world_size, + world_size, dp_size, rank, - a.dtype, ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - num_tokens = a_chunk.shape[0] chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) - print(f"{rank}: shapes {a_chunk.shape}, {chunk_topk_weight.shape}, {chunk_topk_ids.shape}, E={num_experts}") - b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, None, @@ -356,21 +344,6 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): False, ) - #torch.cuda.synchronize() - - if False: - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, - num_experts) - - torch.distributed.all_reduce(tokens_per_expert) - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, - world_size).to(dtype=torch.int32) - - torch.testing.assert_close(tokens_per_expert, - expert_num_tokens, - atol=0, - rtol=0) - b_a = b_a * 1.5 out = torch.full( @@ -388,9 +361,11 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): False, ) - #torch.cuda.synchronize() + torch.cuda.synchronize() - #ata.destroy() + ata.destroy() + + num_tokens = a_chunk.shape[0] return out[:num_tokens] @@ -399,8 +374,8 @@ def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, a, - topk_weight, - topk_ids, + score, + topk, num_experts, ): uid = nvshmem_get_unique_id( @@ -409,8 +384,8 @@ def _pplx_dispatch_combine( nvshmem_init(uid, pgi.rank, pgi.world_size) device = pgi.device + topk_weight, topk_ids = fused_topk(a, score, topk, False) k = a.shape[1] - topk = topk_ids.shape[1] a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) @@ -422,21 +397,19 @@ def _pplx_dispatch_combine( torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - print(f"{pgi.rank}: out shapes {pplx_output.shape}, {torch_output.shape}") - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() -# TODO: M < world_size doesn't appear to be supported by pplx? -@pytest.mark.parametrize("m", [1, 4, 32, 64, 222]) +# TODO: this test point does not work for M == 1 +@pytest.mark.parametrize("m", [4, 32, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #[[4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @requires_pplx def test_pplx_dispatch_combine( m: int, @@ -450,13 +423,10 @@ def test_pplx_dispatch_combine( current_platform.seed_everything(7) world_size, dp_size = world_dp_size device = "cuda" - a = torch.randn((m, k), device=device, dtype=dtype) / 10 score = torch.randn((m, e), device=device, dtype=dtype) - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, topk_weight, topk_ids, e) + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score, topk, e) def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): @@ -476,7 +446,7 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): num_experts=num_experts, experts_per_token=topk, rank=rank, - world_size=pgi.world_size, + world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, @@ -488,12 +458,12 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, - pgi.world_size, + world_size, dp_size, rank, ) - experts = BatchedExperts(max_num_tokens) + experts = BatchedExperts(a.shape[0]) fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -556,7 +526,7 @@ def _pplx_moe( @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @requires_pplx def test_pplx_moe( m: int, From 5d960dfe308119d3fde33933993e2d5a59ac8512 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 04:04:45 +0000 Subject: [PATCH 169/171] fix M=1 pplx test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 26021d201937..d7916b31d3c7 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -300,7 +300,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): assert torch.cuda.current_device() == pgi.local_rank topk = topk_ids.shape[1] - num_tokens, hidden_dim = a.shape[1] + num_tokens, hidden_dim = a.shape block_size = 128 device = pgi.device rank = pgi.rank From 1014679646373b5a176649f505c2e18a2fbd9fef Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 1 May 2025 16:05:58 -0400 Subject: [PATCH 170/171] fixes Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_pplx_moe.py | 5 +- vllm/model_executor/layers/fused_moe/layer.py | 386 ++++++++++++------ .../layers/fused_moe/pplx_dispatch_combine.py | 7 +- vllm/model_executor/models/deepseek_v2.py | 22 +- vllm/model_executor/models/llama4.py | 2 +- vllm/model_executor/models/qwen2_moe.py | 2 +- vllm/model_executor/models/qwen3_moe.py | 2 +- 7 files changed, 282 insertions(+), 144 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index d7916b31d3c7..6e536c268705 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -325,8 +325,9 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): ata, max_num_tokens, world_size, - dp_size, rank, + dp_size, + a.dtype, ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) @@ -459,8 +460,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): ata, max_num_tokens, world_size, - dp_size, rank, + dp_size, ) experts = BatchedExperts(a.shape[0]) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f5c0a452a2bf..9e48d84be2f3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,7 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import get_current_vllm_config +from vllm.config import get_current_vllm_config, ParallelConfig from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -50,6 +50,112 @@ MOE_DP_CHUNK_SIZE = 256 +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_pplx_kernels(self): + return self.use_ep and has_pplx + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, dp_size_, + ep_size_ and vllm's parallel config, determine what level's of parallelism + to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size = tp_size, + tp_rank = tp_rank, + dp_size = dp_size, + dp_rank = dp_rank, + ep_size = 1, + ep_rank = 0, + use_ep = False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor parallel. + # Update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size = 1, + tp_rank = 0, + dp_size = dp_size, + dp_rank = dp_rank, + ep_size = ep_size, + ep_rank = ep_rank, + use_ep = True) + # Adapted from pplx-kernels tests/all_to_all_utils.py @dataclass class MoEConfig: @@ -58,16 +164,45 @@ class MoEConfig: hidden_dim: int num_local_experts: int - dp_size: int - dp_rank: int - ep_size: int - ep_rank: int + moe_parallel_config: FusedMoEParallelConfig in_dtype: torch.dtype # The activation type. # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -299,6 +434,7 @@ def forward_cuda( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -478,6 +614,57 @@ def determine_expert_map( return (local_num_experts, expert_map) +def construct_dispatch_combine(moe: MoEConfig, + quant_config: Optional[QuantizationConfig]) -> FusedMoEQuantizeDispatchCombine: + + dispatch_combine: FusedMoEQuantizeDispatchCombine = None + if moe.use_pplx_kernels: + logger.info("using pplx dispatch") + max_num_tokens = MOE_DP_CHUNK_SIZE + world_size = moe.ep_size + rank = moe.ep_rank + dp_size= moe.ep_size // moe.dp_size # dp_size actually means TP. + + all_to_all = get_all_to_all( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=rank, + world_size=world_size, + dp_size= dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize))) + + dispatch_combine = PplxDispatchCombine( + all_to_all, + max_num_tokens, + world_size, + rank, # just for debugging + dp_size, + moe.in_dtype, + ) + elif True: + logger.info("using standard dispatch") + dispatch_combine = StandardDispatchCombine( + moe.in_dtype, + quant_config.weight_block_size + if quant_config is not None else None, + ) + else: + logger.info("using batched dispatch") + dispatch_combine = BatchedDispatchCombine( + moe.ep_size, + moe.ep_rank, + ) + return dispatch_combine + + class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -528,21 +715,13 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - # Note: here we guard against accessing the TP and DP groups when - # uninitialized (this happens when testing) - self.tp_size = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank() - self.dp_size = (dp_size - if dp_size is not None else get_dp_group().world_size) - self.dp_rank = (0 - if self.dp_size == 1 else get_dp_group().rank_in_group) - self.global_num_experts = num_experts - - # Use expert parallelism instead of tensor parallelism? vllm_config = get_current_vllm_config() - use_ep = (vllm_config.parallel_config.enable_expert_parallel - and self.tp_size * self.dp_size > 1) + self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_ = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()), + dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config) + + self.global_num_experts = num_experts # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 @@ -553,26 +732,15 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix - if use_ep: - # Set TP size to 1 to adjust for EP and adjust EP size and rank - # for DP attention. - self.ep_rank = tp_rank + self.tp_size * self.dp_rank - self.tp_rank = 0 - self.ep_size = self.tp_size * self.dp_size - self.tp_size = 1 - + # Determine expert maps + if self.use_ep: self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) else: - # Adjust TP size for DP attention - self.tp_rank = tp_rank + self.tp_size * self.dp_rank - self.ep_rank = 0 - self.tp_size = self.tp_size * self.dp_size - self.ep_size = 1 - self.local_num_experts = self.global_num_experts - self.expert_map = None + self.local_num_experts, self.expert_map = (self.global_num_experts, None) + self.top_k = top_k assert intermediate_size % self.tp_size == 0 @@ -602,11 +770,8 @@ def __init__( experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, - dp_size=self.dp_size, - dp_rank=self.dp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - in_dtype=params_dtype, # TODO: is this right? + moe_parallel_config=self.moe_parallel_config, + in_dtype=params_dtype, # TODO: is this right? ) # Note: get_quant_method will look at the layer's local_num_experts @@ -622,7 +787,7 @@ def __init__( assert quant_method is not None self.quant_method = quant_method - dispatch_combine = self._construct_dispatch_combine(moe, quant_config) + dispatch_combine = construct_dispatch_combine(moe, quant_config) success = self.quant_method.set_dispatch_combine(dispatch_combine) @@ -648,56 +813,37 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) - # TODO: return Optional? - def _construct_dispatch_combine( - self, - moe: MoEConfig, - quant_config: Optional[QuantizationConfig], - ) -> FusedMoEQuantizeDispatchCombine: - if self.dp_size > 1 and has_pplx: - logger.debug("using pplx dispatch") - max_num_tokens = MOE_DP_CHUNK_SIZE - world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. - rank = moe.ep_rank - - all_to_all = get_all_to_all( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( - (moe.hidden_dim + moe.block_size - 1) // moe.block_size * - torch.float32.itemsize))) - - return PplxDispatchCombine( - all_to_all, - max_num_tokens, - world_size, - dp_size, - rank, - moe.in_dtype, - ) - elif True: - logger.debug("using standard dispatch") - return StandardDispatchCombine( - moe.in_dtype, - quant_config.weight_block_size - if quant_config is not None else None, - ) - else: - logger.debug("using batched dispatch") - return BatchedDispatchCombine( - moe.ep_size, - moe.ep_rank, - ) + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, @@ -1020,6 +1166,19 @@ def naive_multicast(self, x: torch.Tensor, def invalid_pplx(self, hidden_states: torch.Tensor) -> bool: return has_pplx and hidden_states.shape[0] < self.dp_size + def must_reduce_shared_outputs(self) -> bool: + return self.dp_size > 1 and self.use_ep and has_pplx + + def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): + """ + The pplx combine kernel reduce across GPU ranks by default. The pplx kernels are + used when EP is enabled. In that case, this function is a no-op. + """ + if self.dp_size > 1 and self.use_ep and has_pplx: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.use_direct_call or self.invalid_pplx(hidden_states): @@ -1037,13 +1196,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - # TODO: still may be needed for non-pplx, put into dispatcher class. - if False: - hidden_states = self.naive_multicast( - hidden_states, cu_tokens_across_dp_this_iter) - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_dp_this_iter) - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1062,33 +1214,12 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): activation=self.activation, ) - # TODO: needed for non-pplx? - if False and self.dp_size > 1: - if self.dp_rank == 0: - start = 0 - else: - start = cu_tokens_across_dp_this_iter[self.dp_rank - 1] - - end = cu_tokens_across_dp_this_iter[self.dp_rank] - - all_hidden_states = get_dp_group().all_reduce( - final_hidden_states) - final_hidden_states = all_hidden_states[start:end, :] - - # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 - or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs.) - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - if not skip_result_store: full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - max_tokens_across_dp = get_forward_context( - ).dp_metadata.max_tokens_across_dp - moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size + max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE num_tokens = full_hidden_states.size(0) for chunk_start_ in range(0, max_tokens_across_dp, @@ -1109,9 +1240,10 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + if self.dp_size > 1 and self.use_ep and has_pplx: + return self.forward_impl_chunked(hidden_states, router_logits) - # TODO: still may be needed for non-pplx - if False and self.dp_size > 1: + if self.dp_size > 1: ctx = get_forward_context() cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu @@ -1139,8 +1271,7 @@ def forward_impl(self, hidden_states: torch.Tensor, apply_router_weight_on_input=self.apply_router_weight_on_input, ) - # TODO: needed for non-pplx? - if False and self.dp_size > 1: + if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ self.dp_rank - 1] end = cu_tokens_across_dp_cpu[self.dp_rank] @@ -1148,9 +1279,8 @@ def forward_impl(self, hidden_states: torch.Tensor, all_hidden_states = get_dp_group().all_reduce(final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 - or self.ep_size > 1): + if self.reduce_results and (self.tp_size > 1 + or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1202,7 +1332,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, self = forward_context.no_compile_layers[layer_name] assert self.quant_method is not None - return self.forward_impl_chunked(hidden_states, router_logits) + return self.forward_impl(hidden_states, router_logits) def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 4c00edd0b3d8..99049175250a 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -23,8 +23,8 @@ def __init__(self, a2a: pplx.AllToAll, max_num_tokens: int, world_size: int, - dp_size: int, rank: int, + dp_size: int, quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() @@ -33,8 +33,8 @@ def __init__(self, self.block_shape = block_shape self.max_num_tokens = max_num_tokens self.world_size = world_size - self.dp_size = dp_size self.rank = rank + self.dp_size = dp_size self.quant_dtype = quant_dtype def dispatch( @@ -105,7 +105,7 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=a1q.device) # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) @@ -119,6 +119,7 @@ def dispatch( indices=indices, bound_m=bound_m, ) + return expert_x, expert_x_scale, expert_num_tokens def combine( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 350e3a592178..2e12095d3ae9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,8 +32,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -143,7 +142,13 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + # When just tensor-parallel is used, it isn't required + # to reduce the shared_output result. Instead we reduce + # at the end of the forward pass. + # With EP and the pplx kernels - this is no longer viable + # as all GPU ranks in DP, produce the complete set of hidden_states. + # Therefore reduce the shared experts early. + reduce_results=self.experts.must_reduce_shared_outputs(), prefix=f"{prefix}.shared_experts", ) @@ -154,6 +159,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + if hidden_states.dtype != torch.float16: final_hidden_states = self.experts( hidden_states=hidden_states, @@ -172,10 +178,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states + shared_output \ * (1. / self.routed_scaling_factor) - # TODO: check if needed for non-pplx? - if False and self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) @@ -461,6 +465,7 @@ def __init__( o_proj=self.o_proj, ) + self.prefix = prefix self.debug_layer_idx = int(self.prefix.split(".")[-2]) @@ -482,7 +487,6 @@ def forward( k_pe, output_shape=hidden_states.shape) - class DeepseekV2DecoderLayer(nn.Module): def __init__( @@ -560,6 +564,7 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) + hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -578,6 +583,7 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 0fdc30f36f9b..68e427d272c6 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -102,7 +102,7 @@ def forward(self, hidden_states): experts_out = routed_out + shared_out if self.tp_size > 1: - experts_out = tensor_model_parallel_all_reduce(experts_out) + experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(experts_out) return experts_out diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 47d90919ed8f..aca0d658d882 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -156,7 +156,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 97acbaa2ac34..904dc2c452f1 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -137,7 +137,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits=router_logits) final_hidden_states = final_hidden_states if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states) return final_hidden_states.view(orig_shape) From ba8f478146e7f5c6a0bb65cbb7a67f24b1db06bf Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 7 May 2025 02:30:59 -0400 Subject: [PATCH 171/171] zero out attn outputs during profile run Signed-off-by: Varun Sundar Rabindranath --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index fd3be901f4c3..8150d8900bb5 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -901,7 +901,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) num_actual_toks = attn_metadata.num_actual_tokens