Skip to content

Commit 08507f5

Browse files
committed
Add extra checking to torchair_graph_config
Signed-off-by: 22dimensions <[email protected]>
1 parent 2008152 commit 08507f5

File tree

5 files changed

+119
-6
lines changed

5 files changed

+119
-6
lines changed

examples/offline_data_parallel.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,16 @@
5454
--master-port=13345
5555
"""
5656

57-
import os
58-
from time import sleep
5957
import contextlib
6058
import gc
59+
import os
60+
from time import sleep
6161

6262
import torch
63-
6463
from vllm import LLM, SamplingParams
65-
from vllm.utils import get_open_port
6664
from vllm.distributed.parallel_state import ( # noqa E402
6765
destroy_distributed_environment, destroy_model_parallel)
66+
from vllm.utils import get_open_port
6867

6968
os.environ["VLLM_USE_MODELSCOPE"] = "True"
7069
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

tests/ut/test_ascend_config.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,90 @@ def test_check_torchair_supported(self):
236236
for model_type, expected_output in test_cases:
237237
self.assertEqual(_check_torchair_supported(model_type),
238238
expected_output)
239+
240+
@_clean_up_ascend_config
241+
def test_ascend_config_load_error(self):
242+
test_vllm_config = VllmConfig()
243+
# graph_batch_sizes should be list.
244+
with self.assertRaises(TypeError):
245+
test_vllm_config.additional_config = {
246+
"torchair_graph_config": {
247+
"graph_batch_sizes": "fake_size",
248+
},
249+
"refresh": True
250+
}
251+
init_ascend_config(test_vllm_config)
252+
253+
# graph_batch_sizes_init should not be True when graph_batch_sizes is not empty.
254+
with self.assertRaises(ValueError):
255+
test_vllm_config.additional_config = {
256+
"torchair_graph_config": {
257+
"graph_batch_sizes": [1, 2, 4, 8],
258+
"graph_batch_sizes_init": True,
259+
},
260+
"refresh": True
261+
}
262+
init_ascend_config(test_vllm_config)
263+
264+
# torchair graph only works with deepseek.
265+
with self.assertRaises(NotImplementedError):
266+
test_vllm_config.additional_config = {
267+
"torchair_graph_config": {
268+
"enabled": True,
269+
},
270+
"refresh": True
271+
}
272+
init_ascend_config(test_vllm_config)
273+
# torchair graph should not be enabled with eager mode
274+
with self.assertRaises(RuntimeError):
275+
test_vllm_config.additional_config = {
276+
"torchair_graph_config": {
277+
"enabled": True,
278+
},
279+
"refresh": True
280+
}
281+
init_ascend_config(test_vllm_config)
282+
283+
# use_cached_graph should not be enabled without torchair graph mode
284+
with self.assertRaises(RuntimeError):
285+
test_vllm_config.additional_config = {
286+
"torchair_graph_config": {
287+
"enabled": False,
288+
"use_cached_graph": True,
289+
},
290+
"refresh": True
291+
}
292+
init_ascend_config(test_vllm_config)
293+
294+
# graph_batch_sizes_init should not be enabled without torchair graph mode
295+
with self.assertRaises(RuntimeError):
296+
test_vllm_config.additional_config = {
297+
"torchair_graph_config": {
298+
"enabled": False,
299+
"graph_batch_sizes_init": True,
300+
},
301+
"refresh": True
302+
}
303+
init_ascend_config(test_vllm_config)
304+
305+
# enable_multistream_mla should not be enabled without torchair graph mode
306+
with self.assertRaises(RuntimeError):
307+
test_vllm_config.additional_config = {
308+
"torchair_graph_config": {
309+
"enabled": False,
310+
"enable_multistream_mla": True,
311+
},
312+
"refresh": True
313+
}
314+
init_ascend_config(test_vllm_config)
315+
316+
# enable_multistream_moe should not be enabled without torchair graph mode
317+
with self.assertRaises(RuntimeError):
318+
test_vllm_config.additional_config = {
319+
"torchair_graph_config": {
320+
"enabled": False,
321+
"enable_multistream_moe": True,
322+
},
323+
"refresh": True
324+
}
325+
init_ascend_config(test_vllm_config)

vllm_ascend/ascend_config.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,31 @@ def __init__(self, torchair_graph_config):
7676
raise ValueError(
7777
"graph_batch_sizes_init is only valid when graph_batch_sizes is empty"
7878
)
79+
if not self.enabled:
80+
if self.use_cached_graph:
81+
raise RuntimeError(
82+
"use_cached_graph is valid only when Torchair graph mode is enabled"
83+
)
84+
if self.graph_batch_sizes:
85+
raise RuntimeError(
86+
"graph_batch_sizes is valid only when Torchair graph mode is enabled"
87+
)
88+
if self.graph_batch_sizes_init:
89+
raise RuntimeError(
90+
"graph_batch_sizes_init is valid only when Torchair graph mode is enabled"
91+
)
92+
if self.enable_multistream_mla:
93+
raise RuntimeError(
94+
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
95+
)
96+
if self.enable_multistream_moe:
97+
raise RuntimeError(
98+
"enable_multistream_moe is valid only when Torchair graph mode is enabled"
99+
)
100+
if self.enable_kv_nz:
101+
raise RuntimeError(
102+
"enable_kv_nz is valid only when Torchair graph mode is enabled"
103+
)
79104

80105

81106
class AscendSchedulerConfig:

vllm_ascend/models/deepseek_v2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ def __init__(
313313
ascend_config = get_ascend_config()
314314
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
315315
self.enable_multistream_moe = \
316-
ascend_config.torchair_graph_config.enable_multistream_moe
316+
ascend_config.torchair_graph_config.enable_multistream_moe and \
317+
self.torchair_graph_enabled
317318

318319
self.gate = ReplicatedLinear(config.hidden_size,
319320
config.n_routed_experts,

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,8 @@ def __init__(
12531253

12541254
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
12551255
self.enable_multistream_moe = \
1256-
ascend_config.torchair_graph_config.enable_multistream_moe
1256+
ascend_config.torchair_graph_config.enable_multistream_moe and \
1257+
self.torchair_graph_enabled
12571258

12581259
if self.scoring_func != "softmax" and not self.use_grouped_topk:
12591260
raise ValueError("Only softmax scoring function is supported for "

0 commit comments

Comments (0)