
Commit d145a4d

[Fixbug] Fix soc_version for 310p
Signed-off-by: hfadzxy <[email protected]>
1 parent 2693196 · commit d145a4d

3 files changed: +92 −1 lines

tests/ut/test_utils.py

Lines changed: 85 additions & 0 deletions
@@ -311,6 +311,91 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
         # should not register_oot again, thus only called three in this ut
         self.assertEqual(mock_customop.register_oot.call_count, 12)

+    def test_nd_to_nz_spec(self):
+        mask_tensor = torch.ones(32, 64, dtype=torch.bool)
+        output = utils.nd_to_nz_spec(mask_tensor)
+        self.assertEqual(output.shape, (1, 4, 32, 16))  # 64/16=4, 32->32
+
+        mask_tensor = torch.ones(30, 62, dtype=torch.bool)
+        output = utils.nd_to_nz_spec(mask_tensor)
+        self.assertEqual(output.shape, (1, 4, 32, 16))  # 62->64, 30->32
+
+        mask_tensor = torch.ones(16, 16, dtype=torch.bool)
+        output = utils.nd_to_nz_spec(mask_tensor)
+        self.assertTrue(torch.all(output[0, 0, :16, :16] == 1))
+        self.assertTrue(torch.all(output[0, 0, 16:, :] == 0))
+        self.assertTrue(torch.all(output[0, 1:, :, :] == 0))
+
+    def test_dispose_tensor(self):
+        x = torch.ones(10, 10)
+        original_data_ptr = x.data_ptr()
+        utils.dispose_tensor(x)
+        self.assertEqual(x.numel(), 0)
+        self.assertNotEqual(x.data_ptr(), original_data_ptr)
+
+    def test_npu_prefetch(self):
+        input_tensor = torch.ones(10, device='npu')
+        dependency = torch.ones(5, device='npu')
+        utils.npu_prefetch(input_tensor, dependency, enabled=True)
+
+        utils.npu_prefetch(input_tensor, dependency, enabled=False)
+
+
+    def test_init_ascend_soc_version(self):
+        test_cases = [
+            (220, utils.AscendSocVersion.A2),
+            (225, utils.AscendSocVersion.A2),
+            (250, utils.AscendSocVersion.A3),
+            (255, utils.AscendSocVersion.A3),
+            (202, utils.AscendSocVersion.P3),
+            (999, utils.AscendSocVersion.UNDEFINED),
+        ]
+
+        for soc_version, expected in test_cases:
+            with self.subTest(soc_version=soc_version):
+                with mock.patch('torch_npu.npu.get_soc_version', return_value=soc_version):
+                    utils._ascend_soc_version = None  # Reset
+                    utils.init_ascend_soc_version()
+                    result = utils.get_ascend_soc_version()
+                    self.assertEqual(result, expected)
+
+    def test_get_ascend_soc_version(self):
+        utils._ascend_soc_version = None
+        with self.assertRaises(AssertionError):
+            utils.get_ascend_soc_version()
+
+        utils._ascend_soc_version = utils.AscendSocVersion.A2
+        self.assertEqual(utils.get_ascend_soc_version(), utils.AscendSocVersion.A2)
+
+    def test_lmhead_tp_enable(self):
+        with mock.patch('vllm_ascend.utils.get_ascend_config') as mock_config:
+            mock_config.return_value.lmhead_tensor_parallel_size = 2
+            self.assertTrue(utils.lmhead_tp_enable())
+
+            mock_config.return_value.lmhead_tensor_parallel_size = None
+            self.assertFalse(utils.lmhead_tp_enable())
+
+    def test_oproj_tp_enable(self):
+        with mock.patch('vllm_ascend.utils.get_ascend_config') as mock_config:
+            mock_config.return_value.oproj_tensor_parallel_size = 2
+            self.assertTrue(utils.oproj_tp_enable())
+
+            mock_config.return_value.oproj_tensor_parallel_size = None
+            self.assertFalse(utils.oproj_tp_enable())
+
+    def test_mlp_tp_enable(self):
+        with mock.patch.dict(os.environ, {'VLLM_ASCEND_ENABLE_MLP_OPTIMIZE': '1'}):
+            self.assertTrue(utils.mlp_tp_enable())
+
+        with mock.patch.dict(os.environ, {'VLLM_ASCEND_ENABLE_MLP_OPTIMIZE': '0'}):
+            self.assertFalse(utils.mlp_tp_enable())
+
+    def test_matmul_allreduce_enable(self):
+        with mock.patch.dict(os.environ, {'VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE': '1'}):
+            self.assertTrue(utils.matmul_allreduce_enable())
+
+        with mock.patch.dict(os.environ, {'VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE': '0'}):
+            self.assertFalse(utils.matmul_allreduce_enable())

 class TestProfileExecuteDuration(TestBase):
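
The shape comments in `test_nd_to_nz_spec` (64/16=4, 30->32, 62->64) only pin the conversion down indirectly. As a rough sketch, assuming the ND→NZ conversion pads both dimensions up to multiples of 16 and splits the columns into 16-wide fractal blocks (the helper name `nd_to_nz_reference` and the block size are my own inference, not the actual `utils.nd_to_nz_spec` code), something like the following would satisfy the same assertions:

```python
import torch


def nd_to_nz_reference(mask: torch.Tensor, block: int = 16) -> torch.Tensor:
    """Hypothetical reference: pad a (rows, cols) mask up to multiples of
    `block` and reorder it into shape (1, cols/block, padded_rows, block)."""
    rows, cols = mask.shape
    padded_rows = ((rows + block - 1) // block) * block   # 30 -> 32
    padded_cols = ((cols + block - 1) // block) * block   # 62 -> 64
    padded = torch.zeros(padded_rows, padded_cols, dtype=mask.dtype)
    padded[:rows, :cols] = mask
    # (padded_rows, padded_cols) -> (1, padded_cols // block, padded_rows, block)
    return (padded.reshape(padded_rows, padded_cols // block, block)
                  .permute(1, 0, 2)
                  .unsqueeze(0)
                  .contiguous())


print(nd_to_nz_reference(torch.ones(30, 62, dtype=torch.bool)).shape)
# torch.Size([1, 4, 32, 16]) -- the shape asserted in the test
```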

vllm_ascend/utils.py

Lines changed: 5 additions & 1 deletion
@@ -49,6 +49,7 @@

 ASCEND_QUANTIZATION_METHOD = "ascend"
 SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
+ASCEND_310P_SOC_VERSION = 202

 ACL_FORMAT_FRACTAL_ND = 2
 ACL_FORMAT_FRACTAL_NZ = 29
@@ -535,7 +536,8 @@ def register_ascend_customop():
 class AscendSocVersion(Enum):
     A2 = 0
     A3 = 1
-    UNDEFINED = 2
+    P3 = 2
+    UNDEFINED = 3


 _ascend_soc_version = None
@@ -548,6 +550,8 @@ def init_ascend_soc_version():
         _ascend_soc_version = AscendSocVersion.A2
     elif 250 <= soc_version <= 255:
         _ascend_soc_version = AscendSocVersion.A3
+    elif soc_version == ASCEND_310P_SOC_VERSION:
+        _ascend_soc_version = AscendSocVersion.P3
     else:
         _ascend_soc_version = AscendSocVersion.UNDEFINED
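
Taken together, these hunks route SOC version 202 (the Ascend 310P, per the new `ASCEND_310P_SOC_VERSION` constant) to a dedicated `AscendSocVersion.P3` member instead of letting it fall through to `UNDEFINED`. Below is a condensed sketch of the resulting module-level logic, reconstructed from the diff context and the new unit tests; the exact A2 range check and the assertion message are inferred rather than quoted, and it assumes a host where `torch_npu` is importable.

```python
from enum import Enum

import torch_npu  # assumed available; the unit tests mock torch_npu.npu.get_soc_version

ASCEND_310P_SOC_VERSION = 202


class AscendSocVersion(Enum):
    A2 = 0
    A3 = 1
    P3 = 2          # new in this commit: Ascend 310P
    UNDEFINED = 3


_ascend_soc_version = None


def init_ascend_soc_version():
    global _ascend_soc_version
    soc_version = torch_npu.npu.get_soc_version()
    if 220 <= soc_version <= 225:        # A2 range, inferred from the test cases
        _ascend_soc_version = AscendSocVersion.A2
    elif 250 <= soc_version <= 255:
        _ascend_soc_version = AscendSocVersion.A3
    elif soc_version == ASCEND_310P_SOC_VERSION:
        _ascend_soc_version = AscendSocVersion.P3
    else:
        _ascend_soc_version = AscendSocVersion.UNDEFINED


def get_ascend_soc_version():
    # test_get_ascend_soc_version expects an AssertionError before init
    assert _ascend_soc_version is not None
    return _ascend_soc_version
```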

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1465,6 +1465,8 @@ def _select_moe_comm_method(self, num_tokens: int) -> str:
                moe_comm_method = "mc2"
            else:
                moe_comm_method = "allgather"
+        elif soc_version in {AscendSocVersion.P3}:
+            moe_comm_method = "allgather"
         elif soc_version in {AscendSocVersion.A3}:
             moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
         else:
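
The model-runner change is where the new enum member becomes visible: on 310P the MoE communication method is pinned to `allgather`, since the `mc2`/`alltoall` paths used on A2/A3 are not taken there. A simplified, free-standing sketch of the dispatch follows; only the `P3` branch is taken verbatim from the hunk, while the A2 condition and the final `else` are paraphrased from the surrounding context and may differ from the real method.

```python
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


def select_moe_comm_method(num_tokens: int, mc2_tokens_capacity: int) -> str:
    soc_version = get_ascend_soc_version()
    if soc_version == AscendSocVersion.A2:
        # Context above the hunk: mc2 for small batches, otherwise allgather.
        return "mc2" if num_tokens <= mc2_tokens_capacity else "allgather"
    elif soc_version == AscendSocVersion.P3:
        # New branch from this commit: 310P always uses allgather.
        return "allgather"
    elif soc_version == AscendSocVersion.A3:
        return "mc2" if num_tokens <= mc2_tokens_capacity else "alltoall"
    else:
        raise ValueError(f"Unsupported SOC version: {soc_version}")
```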
