Commit c3c2221

[Feat]support dynamic quantization in allgather (#2841)
### What this PR does / why we need it?
[Feat] Support dynamic quantization in allgather.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: main
- vLLM main: vllm-project/vllm@5931b7e

Signed-off-by: withHades <[email protected]>
Signed-off-by: WithHades <[email protected]>
1 parent 07c5866 commit c3c2221
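
In short, the diffs below thread a `with_quant` flag through the all-gather token dispatcher so activations can be dynamically quantized per token before the grouped expert MLP. A minimal sketch of the intended call flow, mirroring the new e2e test added below rather than prescribing an API; the inputs (`a`, `topk_weights`, `topk_ids`, `row_idx`, the int8 weights `w1`/`w2` and their scales) are assumed to be prepared as in that test:

    dispatcher = TokenDispatcherWithAllGather(num_experts=e, top_k=topk,
                                              num_local_experts=e)
    # Dispatch with quantization enabled: the returned dict now also carries
    # a per-token scale under "dynamic_scale".
    out = dispatcher.token_dispatch(hidden_states=a,
                                    topk_weights=topk_weights,
                                    topk_ids=topk_ids,
                                    row_idx=row_idx,
                                    expert_map=None,
                                    apply_router_weight_on_input=False,
                                    with_quant=True)
    # The grouped expert MLP consumes the quantized activations plus their scales.
    expert_out = unified_apply_mlp(hidden_states=out["hidden_states"],
                                   w1=w1, w1_scale=w1_scale,
                                   w2=w2, w2_scale=w2_scale,
                                   group_list=out["group_list"],
                                   group_list_type=out["group_list_type"],
                                   dynamic_scale=out["dynamic_scale"],
                                   with_quant=True)
    combined = dispatcher.token_combine(hidden_states=expert_out, bias=None)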

4 files changed, +112 -4 lines changed

tests/e2e/multicard/test_qwen3_moe.py
Lines changed: 0 additions & 1 deletion

@@ -66,7 +66,6 @@ def test_models_distributed_Qwen3_MOE_W8A8():
             max_model_len=8192,
             tensor_parallel_size=2,
             quantization="ascend",
-            enforce_eager=True,
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)

tests/e2e/singlecard/ops/test_fused_moe.py
Lines changed: 82 additions & 0 deletions

@@ -29,6 +29,7 @@
 from vllm.model_executor.layers.activation import SiluAndMul

 from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather

 NUM_EXPERTS = [8, 64]
@@ -165,6 +166,87 @@ def test_token_dispatcher_with_all_gather(
     torch.npu.reset_peak_memory_stats()


+@pytest.mark.parametrize("m", [1, 33, 64])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("device", DEVICE)
+def test_token_dispatcher_with_all_gather_quant(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    ep_size: int,
+    dtype: torch.dtype,
+    device: str,
+):
+    context_mock = MagicMock()
+    context_mock.fused_moe_state = 0
+    with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context",
+               return_value=context_mock):
+        a = torch.randn((m, k), device=device, dtype=dtype) / 10
+        w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
+        w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
+        w2 = torch.randn((e, n, k), device=device, dtype=torch.int8)
+        w2_scale = torch.empty((e, k), device=device, dtype=dtype)
+
+        score = torch.randn((m, e), device=device, dtype=dtype)
+        expert_map = None
+        local_e = e
+
+        score = torch.softmax(score, dim=-1, dtype=dtype)
+        topk_weights, topk_ids = torch.topk(score, topk)
+        topk_ids = topk_ids.to(torch.int32)
+        row_idx = (torch.arange(
+            0,
+            m * topk,
+            device=device,
+            dtype=torch.int32,
+        ).view(topk, -1).permute(1, 0).contiguous())
+
+        dispatcher_kwargs = {
+            "num_experts": e,
+            "top_k": topk,
+            "num_local_experts": local_e,
+        }
+        dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
+
+        apply_router_weight_on_input = False
+        dispatch_output = dispatcher.token_dispatch(
+            hidden_states=a,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            row_idx=row_idx,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            with_quant=True)
+
+        sorted_hidden_states = dispatch_output["hidden_states"]
+        group_list = dispatch_output["group_list"]
+        group_list_type = dispatch_output.get("group_list_type", 1)
+        dynamic_scale = dispatch_output["dynamic_scale"]
+
+        expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
+                                          w1=w1,
+                                          w1_scale=w1_scale,
+                                          w2=w2,
+                                          w2_scale=w2_scale,
+                                          group_list=group_list,
+                                          group_list_type=group_list_type,
+                                          dynamic_scale=dynamic_scale,
+                                          with_quant=True)
+        combined_output = dispatcher.token_combine(hidden_states=expert_output,
+                                                   bias=None)
+        assert combined_output.shape == (m, k)
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
+
+
 @pytest.mark.parametrize("m", [1, 33, 64])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("e", NUM_EXPERTS)

tests/ut/ops/test_token_dispatcher.py
Lines changed: 27 additions & 1 deletion

@@ -221,7 +221,7 @@ def test_token_dispatch_with_expert_map(self):

         self.assertEqual(results["group_list_type"], 1)

-    def test_token_dispatch_with_quant(self):
+    def test_token_dispatch_without_quant(self):
         kwargs = {
             "apply_router_weight_on_input": False,
             "top_k": 2,
@@ -241,6 +241,32 @@ def test_token_dispatch_with_quant(self):

         self.assertEqual(results["group_list_type"], 1)

+    def test_token_dispatch_with_quant(self):
+        kwargs = {
+            "apply_router_weight_on_input": False,
+            "top_k": 2,
+            "max_num_tokens": 100,
+            "ep_size": 2,
+            "num_experts": 128,
+        }
+        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
+
+        hidden_states = torch.randn(3, 128)
+        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
+        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
+
+        results = self.dispatcher_quant.token_dispatch(hidden_states,
+                                                       topk_weights,
+                                                       topk_ids,
+                                                       self.row_idx,
+                                                       None,
+                                                       with_quant=True)
+
+        self.assertIsNotNone(results["hidden_states"])
+        self.assertIsNotNone(results["group_list"])
+        self.assertIsNotNone(results["dynamic_scale"])
+        self.assertEqual(results["group_list_type"], 1)
+
     def test_token_combine_with_expert_map(self):
         self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
         self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])

vllm_ascend/ops/moe/token_dispatcher.py
Lines changed: 3 additions & 2 deletions

@@ -367,7 +367,7 @@ def token_dispatch(self,
         last_expert_idx = self.num_experts_local
         global_num_experts = self.num_experts_local

-        sorted_hidden_states, self.expanded_row_idx, expert_tokens, _ = (
+        sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
             torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
                 topk_ids,
@@ -376,14 +376,15 @@ def token_dispatch(self,
                 expert_tokens_num_type=1,
                 expert_tokens_num_flag=True,
                 active_expert_range=[first_expert_idx, last_expert_idx],
-                quant_mode=-1,
+                quant_mode=1 if self.with_quant else -1,
             ))
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1  # `count` mode
         return {
             "group_list_type": group_list_type,
             "hidden_states": sorted_hidden_states,
             "group_list": expert_tokens,
+            "dynamic_scale": pertoken_scale if self.with_quant else None,
         }

     def token_combine(self,
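
The dispatcher's return contract grows one field: "dynamic_scale" carries the per-token scales returned by `npu_moe_init_routing_v2` when `with_quant` is set (per the PR title, `quant_mode=1` is presumably the op's dynamic-quantization mode), and is None otherwise. A caller-side sketch, where `dispatch_output` stands for the dict returned by `token_dispatch`:

    # None on the unquantized path, so existing non-quant callers are unaffected.
    dynamic_scale = dispatch_output.get("dynamic_scale")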
