@@ -311,6 +311,91 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
311
311
# should not register_oot again, thus only called three in this ut
312
312
self .assertEqual (mock_customop .register_oot .call_count , 12 )
313
313
314
+ def test_nd_to_nz_spec (self ):
315
+ mask_tensor = torch .ones (32 , 64 , dtype = torch .bool )
316
+ output = utils .nd_to_nz_spec (mask_tensor )
317
+ self .assertEqual (output .shape , (1 , 4 , 32 , 16 )) # 64/16=4, 32->32
318
+
319
+ mask_tensor = torch .ones (30 , 62 , dtype = torch .bool )
320
+ output = utils .nd_to_nz_spec (mask_tensor )
321
+ self .assertEqual (output .shape , (1 , 4 , 32 , 16 )) # 62->64, 30->32
322
+
323
+ mask_tensor = torch .ones (16 , 16 , dtype = torch .bool )
324
+ output = utils .nd_to_nz_spec (mask_tensor )
325
+ self .assertTrue (torch .all (output [0 , 0 , :16 , :16 ] == 1 ))
326
+ self .assertTrue (torch .all (output [0 , 0 , 16 :, :] == 0 ))
327
+ self .assertTrue (torch .all (output [0 , 1 :, :, :] == 0 ))
328
+
329
+ def test_dispose_tensor (self ):
330
+ x = torch .ones (10 , 10 )
331
+ original_data_ptr = x .data_ptr ()
332
+ utils .dispose_tensor (x )
333
+ self .assertEqual (x .numel (), 0 )
334
+ self .assertNotEqual (x .data_ptr (), original_data_ptr )
335
+
336
+ def test_npu_prefetch (self ):
337
+ input_tensor = torch .ones (10 , device = 'npu' )
338
+ dependency = torch .ones (5 , device = 'npu' )
339
+ utils .npu_prefetch (input_tensor , dependency , enabled = True )
340
+
341
+ utils .npu_prefetch (input_tensor , dependency , enabled = False )
342
+
343
+
344
+ def test_init_ascend_soc_version (self ):
345
+ test_cases = [
346
+ (220 , utils .AscendSocVersion .A2 ),
347
+ (225 , utils .AscendSocVersion .A2 ),
348
+ (250 , utils .AscendSocVersion .A3 ),
349
+ (255 , utils .AscendSocVersion .A3 ),
350
+ (202 , utils .AscendSocVersion .P3 ),
351
+ (999 , utils .AscendSocVersion .UNDEFINED ),
352
+ ]
353
+
354
+ for soc_version , expected in test_cases :
355
+ with self .subTest (soc_version = soc_version ):
356
+ with mock .patch ('torch_npu.npu.get_soc_version' , return_value = soc_version ):
357
+ utils ._ascend_soc_version = None # Reset
358
+ utils .init_ascend_soc_version ()
359
+ result = utils .get_ascend_soc_version ()
360
+ self .assertEqual (result , expected )
361
+
362
+ def test_get_ascend_soc_version (self ):
363
+ utils ._ascend_soc_version = None
364
+ with self .assertRaises (AssertionError ):
365
+ utils .get_ascend_soc_version ()
366
+
367
+ utils ._ascend_soc_version = utils .AscendSocVersion .A2
368
+ self .assertEqual (utils .get_ascend_soc_version (), utils .AscendSocVersion .A2 )
369
+
370
+ def test_lmhead_tp_enable (self ):
371
+ with mock .patch ('vllm_ascend.utils.get_ascend_config' ) as mock_config :
372
+ mock_config .return_value .lmhead_tensor_parallel_size = 2
373
+ self .assertTrue (utils .lmhead_tp_enable ())
374
+
375
+ mock_config .return_value .lmhead_tensor_parallel_size = None
376
+ self .assertFalse (utils .lmhead_tp_enable ())
377
+
378
+ def test_oproj_tp_enable (self ):
379
+ with mock .patch ('vllm_ascend.utils.get_ascend_config' ) as mock_config :
380
+ mock_config .return_value .oproj_tensor_parallel_size = 2
381
+ self .assertTrue (utils .oproj_tp_enable ())
382
+
383
+ mock_config .return_value .oproj_tensor_parallel_size = None
384
+ self .assertFalse (utils .oproj_tp_enable ())
385
+
386
+ def test_mlp_tp_enable (self ):
387
+ with mock .patch .dict (os .environ , {'VLLM_ASCEND_ENABLE_MLP_OPTIMIZE' : '1' }):
388
+ self .assertTrue (utils .mlp_tp_enable ())
389
+
390
+ with mock .patch .dict (os .environ , {'VLLM_ASCEND_ENABLE_MLP_OPTIMIZE' : '0' }):
391
+ self .assertFalse (utils .mlp_tp_enable ())
392
+
393
+ def test_matmul_allreduce_enable (self ):
394
+ with mock .patch .dict (os .environ , {'VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE' : '1' }):
395
+ self .assertTrue (utils .matmul_allreduce_enable ())
396
+
397
+ with mock .patch .dict (os .environ , {'VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE' : '0' }):
398
+ self .assertFalse (utils .matmul_allreduce_enable ())
314
399
315
400
class TestProfileExecuteDuration (TestBase ):
316
401
0 commit comments