diff --git a/docs/source/tutorials/single_npu_audio.md b/docs/source/tutorials/single_npu_audio.md
index 06e093c571..734e6dc708 100644
--- a/docs/source/tutorials/single_npu_audio.md
+++ b/docs/source/tutorials/single_npu_audio.md
@@ -90,8 +90,7 @@ def main(audio_count: int):
     llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct",
               max_model_len=4096,
               max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count},
-              enforce_eager=True)
+              limit_mm_per_prompt={"audio": audio_count})
 
     inputs = prepare_inputs(audio_count)
 
diff --git a/docs/source/tutorials/single_npu_multimodal.md b/docs/source/tutorials/single_npu_multimodal.md
index 8acd9cf9c1..8c19651dc6 100644
--- a/docs/source/tutorials/single_npu_multimodal.md
+++ b/docs/source/tutorials/single_npu_multimodal.md
@@ -57,7 +57,6 @@ llm = LLM(
     model=MODEL_PATH,
     max_model_len=16384,
     limit_mm_per_prompt={"image": 10},
-    enforce_eager=True,
 )
 
 sampling_params = SamplingParams(
@@ -146,8 +145,7 @@ docker run --rm \
 vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
 --dtype bfloat16 \
 --max_model_len 16384 \
---max-num-batched-tokens 16384 \
---enforce-eager
+--max-num-batched-tokens 16384
 ```
 
 :::{note}
diff --git a/requirements-dev.txt b/requirements-dev.txt
index cbd851ed37..4f36cd70d9 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -15,3 +15,5 @@ regex
 sentence_transformers
 ray>=2.47.1
 protobuf==4.25.6
+librosa
+soundfile
diff --git a/tests/e2e/singlecard/test_offline_inference.py b/tests/e2e/singlecard/test_offline_inference.py
index e5b9364de1..c6c68e55e8 100644
--- a/tests/e2e/singlecard/test_offline_inference.py
+++ b/tests/e2e/singlecard/test_offline_inference.py
@@ -27,6 +27,7 @@
 import vllm  # noqa: F401
 from modelscope import snapshot_download  # type: ignore[import-untyped]
 from vllm import SamplingParams
+from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 
 import vllm_ascend  # noqa: F401
@@ -36,12 +37,18 @@
     "Qwen/Qwen2.5-0.5B-Instruct",
     "Qwen/Qwen3-0.6B-Base",
 ]
-MULTIMODALITY_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
+MULTIMODALITY_VL_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
+MULTIMODALITY_AUDIO_MODELS = ["Qwen/Qwen2-Audio-7B-Instruct"]
 
 QUANTIZATION_MODELS = [
     "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8",
 ]
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
+AUDIO_ASSETS = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
+AUDIO_PROMPT_TEMPLATES = {
+    1: "What is recited in the audio?",
+    2: "What sport and what nursery rhyme are referenced?"
+}
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -84,8 +91,8 @@ def test_quantization_models(model: str, max_tokens: int) -> None:
         vllm_model.generate_greedy(example_prompts, max_tokens)
 
 
-@pytest.mark.parametrize("model", MULTIMODALITY_MODELS)
-def test_multimodal(model, prompt_template, vllm_runner):
+@pytest.mark.parametrize("model", MULTIMODALITY_VL_MODELS)
+def test_multimodal_vl(model, prompt_template, vllm_runner):
     image = ImageAsset("cherry_blossom") \
         .pil_image.convert("RGB")
     img_questions = [
@@ -108,6 +115,45 @@
                                    max_tokens=64)
 
 
+def prepare_audio_inputs(audio_count: int):
+    audio_prompt = "".join([
+        f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
+        for idx in range(audio_count)
+    ])
+    question = AUDIO_PROMPT_TEMPLATES[audio_count]
+    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+              "<|im_start|>user\n"
+              f"{audio_prompt}{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    mm_data = {
+        "audio":
+        [asset.audio_and_sample_rate for asset in AUDIO_ASSETS[:audio_count]]
+    }
+    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
+    return inputs
+
+
+@pytest.mark.parametrize("model", MULTIMODALITY_AUDIO_MODELS)
+@pytest.mark.parametrize("audio_count", [2])
+@pytest.mark.parametrize("max_tokens", [10])
+def test_multimodal_audio(model: str, audio_count: int,
+                          max_tokens: int) -> None:
+    inputs = prepare_audio_inputs(audio_count)
+
+    sampling_params = SamplingParams(temperature=0.2,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=None)
+
+    with VllmRunner(model,
+                    max_model_len=4096,
+                    max_num_seqs=5,
+                    enforce_eager=False,
+                    dtype="bfloat16",
+                    limit_mm_per_prompt={"audio": audio_count},
+                    gpu_memory_utilization=0.9) as vllm_model:
+        vllm_model.generate(inputs, sampling_params=sampling_params)
+
+
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION": "1"})
 def test_models_topk() -> None:
     example_prompts = [
diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 771684952c..94d37a023a 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -260,6 +260,61 @@ def test_vllm_version_is(self):
         hits = utils.vllm_version_is.cache_info().hits
         self.assertEqual(hits, 1)
 
+    def test_get_max_hidden_layers(self):
+        from transformers import PretrainedConfig
+
+        class SimpleConfig(PretrainedConfig):
+
+            def __init__(self, num_hidden_layers=12):
+                self.num_hidden_layers = num_hidden_layers
+
+            def to_dict(self):
+                return {"num_hidden_layers": self.num_hidden_layers}
+
+        self.assertEqual(utils.get_max_hidden_layers(SimpleConfig()), 12)
+        self.assertEqual(utils.get_max_hidden_layers(SimpleConfig(24)), 24)
+
+        class NestedConfig(PretrainedConfig):
+
+            def to_dict(self):
+                return {
+                    "model": {
+                        "encoder": {
+                            "num_hidden_layers": 8
+                        },
+                        "decoder": {
+                            "num_hidden_layers": 12
+                        }
+                    },
+                    "other_setting": True
+                }
+
+        self.assertEqual(utils.get_max_hidden_layers(NestedConfig()), 12)
+
+        class MultiValueConfig(PretrainedConfig):
+
+            def to_dict(self):
+                return {
+                    "num_hidden_layers": 6,
+                    "submodule": {
+                        "num_hidden_layers": 18,
+                        "subsub": {
+                            "num_hidden_layers": 9
+                        }
+                    }
+                }
+
+        self.assertEqual(utils.get_max_hidden_layers(MultiValueConfig()), 18)
+
+        class NoLayerConfig(PretrainedConfig):
+
+            def to_dict(self):
+                return {"attention_heads": 8}
+
+        with self.assertRaises(ValueError) as context:
+            utils.get_max_hidden_layers(NoLayerConfig())
+        self.assertIn("num_hidden_layers", str(context.exception))
+
     def test_update_aclgraph_sizes(self):
         # max_num_batch_sizes < len(original_sizes)
         test_compilation_config = CompilationConfig(
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index f859b843c9..3e1785f311 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -288,6 +288,24 @@ def vllm_version_is(target_vllm_version: str):
         "format of x.y.z.")
 
 
+def get_max_hidden_layers(hf_config) -> int:
+    cfg_dict = hf_config.to_dict()
+    layer_counts = []
+
+    def _rec_find(d):
+        if isinstance(d, dict):
+            for k, v in d.items():
+                if k == "num_hidden_layers" and isinstance(v, int):
+                    layer_counts.append(v)
+                else:
+                    _rec_find(v)
+
+    _rec_find(cfg_dict)
+    if not layer_counts:
+        raise ValueError("num_hidden_layers not found in model config.")
+    return max(layer_counts)
+
+
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
     # Store original configuration and temporarily clear it
@@ -296,7 +314,11 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         compilation_config.cudagraph_capture_sizes, None
 
     # Calculate parallel configuration factor
-    num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
+    hf_config = vllm_config.model_config.hf_config
+    if hasattr(hf_config, 'num_hidden_layers'):
+        num_hidden_layers = hf_config.num_hidden_layers
+    else:
+        num_hidden_layers = get_max_hidden_layers(hf_config)
     parallel_config = vllm_config.parallel_config
 
     # TODO: Find out whether we need to take into account the pp_size
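
For context on the `update_aclgraph_sizes` fallback above: multi-component models such as Qwen2-Audio tend to expose `num_hidden_layers` only inside nested sub-configs (separate audio and text configs) rather than as a top-level attribute, so the ACL graph sizing now scans the whole config dict and takes the largest value. Below is a minimal, standalone sketch of that recursive scan; the function name `max_hidden_layers_from_dict` and the `fake_cfg` dict are illustrative stand-ins, while the real helper is `get_max_hidden_layers` in `vllm_ascend/utils.py`.

```python
# Standalone sketch (not the vllm_ascend implementation itself): walk a
# possibly nested HF config dict and return the largest "num_hidden_layers"
# value found, raising if none is present.
def max_hidden_layers_from_dict(cfg_dict: dict) -> int:
    layer_counts: list[int] = []

    def _rec_find(d):
        # Only dicts are descended into; other values are ignored.
        if isinstance(d, dict):
            for k, v in d.items():
                if k == "num_hidden_layers" and isinstance(v, int):
                    layer_counts.append(v)
                else:
                    _rec_find(v)

    _rec_find(cfg_dict)
    if not layer_counts:
        raise ValueError("num_hidden_layers not found in model config.")
    return max(layer_counts)


if __name__ == "__main__":
    # Made-up config shaped like a multi-component model (audio encoder plus
    # text decoder); not a real Hugging Face config.
    fake_cfg = {
        "audio_config": {"num_hidden_layers": 32},
        "text_config": {"num_hidden_layers": 28},
    }
    print(max_hidden_layers_from_dict(fake_cfg))  # prints 32
```

Taking the maximum over all nested values gives a conservative layer count for graph-capture sizing, which lines up with the tutorials above dropping `enforce_eager` and the new e2e audio test running with `enforce_eager=False`.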