@@ -29,6 +29,7 @@
 from vllm.model_executor.layers.activation import SiluAndMul
 
 from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
 
 NUM_EXPERTS = [8, 64]
@@ -165,6 +166,87 @@ def test_token_dispatcher_with_all_gather(
     torch.npu.reset_peak_memory_stats()
 
 
+@pytest.mark.parametrize("m", [1, 33, 64])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("device", DEVICE)
+def test_token_dispatcher_with_all_gather_quant(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    ep_size: int,
+    dtype: torch.dtype,
+    device: str,
+):
+    context_mock = MagicMock()
+    context_mock.fused_moe_state = 0
+    with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context",
+               return_value=context_mock):
+        a = torch.randn((m, k), device=device, dtype=dtype) / 10
+        w1 = torch.randint(-128, 128, (e, k, 2 * n), device=device, dtype=torch.int8)
+        w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
+        w2 = torch.randint(-128, 128, (e, n, k), device=device, dtype=torch.int8)
+        w2_scale = torch.empty((e, k), device=device, dtype=dtype)
+
+        score = torch.randn((m, e), device=device, dtype=dtype)
+        expert_map = None
+        local_e = e
+
+        score = torch.softmax(score, dim=-1, dtype=dtype)
+        topk_weights, topk_ids = torch.topk(score, topk)
+        topk_ids = topk_ids.to(torch.int32)
+        row_idx = (torch.arange(
+            0,
+            m * topk,
+            device=device,
+            dtype=torch.int32,
+        ).view(topk, -1).permute(1, 0).contiguous())
+
+        dispatcher_kwargs = {
+            "num_experts": e,
+            "top_k": topk,
+            "num_local_experts": local_e,
+        }
+        dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
+
+        apply_router_weight_on_input = False
+        dispatch_output = dispatcher.token_dispatch(
+            hidden_states=a,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            row_idx=row_idx,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            with_quant=True)
+
+        sorted_hidden_states = dispatch_output["hidden_states"]
+        group_list = dispatch_output["group_list"]
+        group_list_type = dispatch_output.get("group_list_type", 1)
+        dynamic_scale = dispatch_output["dynamic_scale"]
+
+        expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
+                                          w1=w1,
+                                          w1_scale=w1_scale,
+                                          w2=w2,
+                                          w2_scale=w2_scale,
+                                          group_list=group_list,
+                                          group_list_type=group_list_type,
+                                          dynamic_scale=dynamic_scale,
+                                          with_quant=True)
+        combined_output = dispatcher.token_combine(hidden_states=expert_output,
+                                                   bias=None)
+        assert combined_output.shape == (m, k)
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
+
+
 @pytest.mark.parametrize("m", [1, 33, 64])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("e", NUM_EXPERTS)