diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py
index e9993fd84061..34465b7d2708 100644
--- a/tests/entrypoints/llm/test_generate.py
+++ b/tests/entrypoints/llm/test_generate.py
@@ -71,6 +71,26 @@ def test_multiple_sampling_params(llm: LLM):
     assert len(PROMPTS) == len(outputs)
 
 
+def test_multiple_priority(llm: LLM):
+    # Generate works when priority is None
+    outputs = llm.generate(PROMPTS, sampling_params=None, priority=None)
+    assert len(PROMPTS) == len(outputs)
+
+    # Generate works when length of priority is same as the len(PROMPTS)
+    outputs = llm.generate(PROMPTS, sampling_params=None, priority=[0] * len(PROMPTS))
+    assert len(PROMPTS) == len(outputs)
+
+    # Exception raised, if the length of priority does not match the length of prompts
+    with pytest.raises(ValueError):
+        outputs = llm.generate(
+            PROMPTS, sampling_params=None, priority=[0] * (len(PROMPTS) - 1)
+        )
+
+    # Exception raised, if the priority list is empty
+    with pytest.raises(ValueError):
+        outputs = llm.generate(PROMPTS, sampling_params=None, priority=[])
+
+
 def test_max_model_len():
     max_model_len = 20
     llm = LLM(
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index c15b70a06809..6a9cf554a718 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1564,6 +1564,12 @@ def _validate_and_add_requests(
             raise ValueError(
                 "The lengths of prompts and lora_request must be the same."
             )
+        if priority is not None and len(priority) != num_requests:
+            raise ValueError(
+                "The lengths of prompts "
+                f"({num_requests}) and priority ({len(priority)}) "
+                "must be the same."
+            )
 
         for sp in params if isinstance(params, Sequence) else (params,):
             if isinstance(sp, SamplingParams):