Skip to content

Commit 82c50d6

Browse files
committed
Add extra checking to torchair_graph_config
Signed-off-by: 22dimensions <[email protected]>
1 parent 2008152 commit 82c50d6

File tree

5 files changed

+100
-6
lines changed

5 files changed

+100
-6
lines changed

examples/offline_data_parallel.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,16 @@
5454
--master-port=13345
5555
"""
5656

57-
import os
58-
from time import sleep
5957
import contextlib
6058
import gc
59+
import os
60+
from time import sleep
6161

6262
import torch
63-
6463
from vllm import LLM, SamplingParams
65-
from vllm.utils import get_open_port
6664
from vllm.distributed.parallel_state import ( # noqa E402
6765
destroy_distributed_environment, destroy_model_parallel)
66+
from vllm.utils import get_open_port
6867

6968
os.environ["VLLM_USE_MODELSCOPE"] = "True"
7069
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

tests/ut/test_ascend_config.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,71 @@ def test_check_torchair_supported(self):
236236
for model_type, expected_output in test_cases:
237237
self.assertEqual(_check_torchair_supported(model_type),
238238
expected_output)
239+
240+
@_clean_up_ascend_config
241+
def test_ascend_config_load_error(self):
242+
test_vllm_config = VllmConfig()
243+
# graph_batch_sizes should be list.
244+
with self.assertRaises(TypeError):
245+
test_vllm_config.additional_config = {
246+
"torchair_graph_config": {
247+
"graph_batch_sizes": "fake_size",
248+
},
249+
"refresh": True
250+
}
251+
init_ascend_config(test_vllm_config)
252+
253+
# use_cached_graph should not be enabled without torchair graph mode
254+
with self.assertRaises(RuntimeError):
255+
test_vllm_config.additional_config = {
256+
"torchair_graph_config": {
257+
"enabled": False,
258+
"use_cached_graph": True,
259+
},
260+
"refresh": True
261+
}
262+
init_ascend_config(test_vllm_config)
263+
264+
# graph_batch_sizes_init should not be enabled without torchair graph mode
265+
with self.assertRaises(RuntimeError):
266+
test_vllm_config.additional_config = {
267+
"torchair_graph_config": {
268+
"enabled": False,
269+
"graph_batch_sizes_init": True,
270+
},
271+
"refresh": True
272+
}
273+
init_ascend_config(test_vllm_config)
274+
275+
# enable_multistream_mla should not be enabled without torchair graph mode
276+
with self.assertRaises(RuntimeError):
277+
test_vllm_config.additional_config = {
278+
"torchair_graph_config": {
279+
"enabled": False,
280+
"enable_multistream_mla": True,
281+
},
282+
"refresh": True
283+
}
284+
init_ascend_config(test_vllm_config)
285+
286+
# enable_multistream_moe should not be enabled without torchair graph mode
287+
with self.assertRaises(RuntimeError):
288+
test_vllm_config.additional_config = {
289+
"torchair_graph_config": {
290+
"enabled": False,
291+
"enable_multistream_moe": True,
292+
},
293+
"refresh": True
294+
}
295+
init_ascend_config(test_vllm_config)
296+
297+
# enable_kv_nz should not be enabled without torchair graph mode
298+
with self.assertRaises(RuntimeError):
299+
test_vllm_config.additional_config = {
300+
"torchair_graph_config": {
301+
"enabled": False,
302+
"enable_kv_nz": True,
303+
},
304+
"refresh": True
305+
}
306+
init_ascend_config(test_vllm_config)

vllm_ascend/ascend_config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,31 @@ def __init__(self, torchair_graph_config):
7676
raise ValueError(
7777
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
7878
)
79+
if not self.enabled:
80+
if self.use_cached_graph:
81+
raise RuntimeError(
82+
"use_cached_graph is valid only when Torchair graph mode is enabled"
83+
)
84+
if self.graph_batch_sizes:
85+
raise RuntimeError(
86+
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
87+
)
88+
if self.graph_batch_sizes_init:
89+
raise RuntimeError(
90+
"graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
91+
)
92+
if self.enable_multistream_mla:
93+
raise RuntimeError(
94+
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
95+
)
96+
if self.enable_multistream_moe:
97+
raise RuntimeError(
98+
"enable_multistream_moe is valid only when Torchair graph mode is enabled"
99+
)
100+
if self.enable_kv_nz:
101+
raise RuntimeError(
102+
"enable_kv_nz is valid only when Torchair graph mode is enabled"
103+
)
79104

80105

81106
class AscendSchedulerConfig:

vllm_ascend/models/deepseek_v2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ def __init__(
313313
ascend_config = get_ascend_config()
314314
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
315315
self.enable_multistream_moe = \
316-
ascend_config.torchair_graph_config.enable_multistream_moe
316+
ascend_config.torchair_graph_config.enable_multistream_moe and \
317+
self.torchair_graph_enabled
317318

318319
self.gate = ReplicatedLinear(config.hidden_size,
319320
config.n_routed_experts,

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,8 @@ def __init__(
12531253

12541254
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
12551255
self.enable_multistream_moe = \
1256-
ascend_config.torchair_graph_config.enable_multistream_moe
1256+
ascend_config.torchair_graph_config.enable_multistream_moe and \
1257+
self.torchair_graph_enabled
12571258

12581259
if self.scoring_func != "softmax" and not self.use_grouped_topk:
12591260
raise ValueError("Only softmax scoring function is supported for "

0 commit comments

Comments
 (0)