diff --git a/tests/e2e/singlecard/test_ascend_scheduler.py b/tests/e2e/singlecard/test_ascend_scheduler.py
index f7830dd3dd..2aab523689 100644
--- a/tests/e2e/singlecard/test_ascend_scheduler.py
+++ b/tests/e2e/singlecard/test_ascend_scheduler.py
@@ -1,743 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
-
 import pytest
-import torch
-from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
-                         SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.sampling_params import SamplingParams
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
-from vllm.v1.structured_output import StructuredOutputManager
 
 from tests.e2e.conftest import VllmRunner
 from tests.e2e.model_utils import check_outputs_equal
-from vllm_ascend.core.scheduler import AscendScheduler
-from vllm_ascend.utils import vllm_version_is
 
-EOS_TOKEN_ID = 50256
 MODEL = "Qwen/Qwen3-0.6B"
 
 
-def create_scheduler(
-    model: str = MODEL,
-    max_num_seqs: int = 16,
-    max_num_batched_tokens: int = 8192,
-    enable_prefix_caching: Optional[bool] = None,
-    long_prefill_token_threshold: int = 0,
-    disable_chunked_mm_input: bool = False,
-    use_kv_connector: bool = False,
-    num_blocks: int = 10000,
-    block_size: int = 16,
-    max_model_len: Optional[int] = None,
-    num_speculative_tokens: Optional[int] = None,
-    enable_chunked_prefill: bool = False,
-) -> AscendScheduler:
-    '''Create scheduler under test.
-
-    Args:
-      model: model under test
-      max_num_seqs: max sequences to schedule
-      max_num_batch_tokens: max num tokens to batch
-      enable_prefix_caching: optionally force APC config
-                             (True/False) or use default
-                             (None)
-
-    Returns:
-      {class}`Scheduler` instance
-    '''
-    if max_model_len is None:
-        max_model_len = max_num_batched_tokens
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=max_num_seqs,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_model_len,
-        long_prefill_token_threshold=long_prefill_token_threshold,
-        disable_chunked_mm_input=disable_chunked_mm_input,
-        enable_chunked_prefill=enable_chunked_prefill,
-    )
-    model_config = ModelConfig(
-        model=model,
-        task="auto",
-        tokenizer=model,
-        tokenizer_mode="auto",
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-    )
-    # Cache config, optionally force APC
-    kwargs_cache = ({} if enable_prefix_caching is None else {
-        'enable_prefix_caching': enable_prefix_caching
-    })
-    cache_config = CacheConfig(
-        block_size=block_size,
-        gpu_memory_utilization=0.9,
-        swap_space=0,
-        cache_dtype="auto",
-        **kwargs_cache,
-    )
-    kv_transfer_config = KVTransferConfig(
-        kv_connector="SharedStorageConnector",
-        kv_role="kv_both",
-        kv_connector_extra_config={"shared_storage_path": "local_storage"},
-    ) if use_kv_connector else None
-
-    speculative_config: Optional[SpeculativeConfig] = None
-    if num_speculative_tokens is not None:
-        speculative_config = SpeculativeConfig(
-            model="ngram", num_speculative_tokens=num_speculative_tokens)
-
-    vllm_config = VllmConfig(
-        scheduler_config=scheduler_config,
-        model_config=model_config,
-        cache_config=cache_config,
-        kv_transfer_config=kv_transfer_config,
-        speculative_config=speculative_config,
-    )
-    kv_cache_config = KVCacheConfig(
-        num_blocks=num_blocks,  # A large number of blocks to hold all requests
-        kv_cache_tensors=[],
-        kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
-        ],
-    )
-    cache_config.num_gpu_blocks = num_blocks
-    return AscendScheduler(
-        vllm_config=vllm_config,
-        kv_cache_config=kv_cache_config,
-        log_stats=True,
-        structured_output_manager=StructuredOutputManager(vllm_config),
-    )
-
-
-def create_requests(num_requests: int,
-                    num_tokens: int = 10,
-                    mm_positions: Optional[list[PlaceholderRange]] = None,
-                    max_tokens: int = 16,
-                    stop_token_ids: Optional[list[int]] = None,
-                    prompt_logprobs: Optional[int] = None):
-    sampling_params = SamplingParams(ignore_eos=False,
-                                     max_tokens=max_tokens,
-                                     stop_token_ids=stop_token_ids,
-                                     prompt_logprobs=prompt_logprobs)
-    requests = []
-    for i in range(num_requests):
-        if mm_positions is not None:
-            mm_position = mm_positions[i]
-            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
-        else:
-            mm_position = None
-            mm_inputs = None
-        request = Request(
-            request_id=f"{i}",
-            prompt_token_ids=[i] * num_tokens,
-            sampling_params=sampling_params,
-            multi_modal_inputs=mm_inputs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
-            eos_token_id=EOS_TOKEN_ID,
-            pooling_params=None,
-        )
-        requests.append(request)
-    return requests
-
-
-def test_add_requests():
-    scheduler = create_scheduler()
-    requests = create_requests(num_requests=10)
-
-    for i, request in enumerate(requests):
-        scheduler.add_request(request)
-        assert request.request_id in scheduler.requests
-        assert len(scheduler.waiting) == i + 1
-
-
-def test_finish_request():
-    scheduler = create_scheduler()
-    requests = create_requests(num_requests=10)
-    for request in requests:
-        scheduler.add_request(request)
-
-    for i, request in enumerate(requests):
-        scheduler.finish_requests(request.request_id,
-                                  RequestStatus.FINISHED_ABORTED)
-        assert request.request_id not in scheduler.requests
-        assert len(scheduler.waiting) == 9 - i
-
-
-def test_get_num_unfinished_requests():
-    scheduler = create_scheduler()
-    requests = create_requests(num_requests=10)
-    for request in requests:
-        scheduler.add_request(request)
-
-    for i, request in enumerate(requests):
-        scheduler.finish_requests(request.request_id,
-                                  RequestStatus.FINISHED_STOPPED)
-        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
-
-
-@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
-    (None, None),
-    (True, 5),
-])
-def test_schedule(enable_prefix_caching: Optional[bool],
-                  prompt_logprobs: Optional[int]):
-    '''Test scheduling.
-    Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
-    '''
-    scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
-    requests = create_requests(num_requests=10,
-                               prompt_logprobs=prompt_logprobs)
-    for request in requests:
-        scheduler.add_request(request)
-
-    # Test initial scheduling
-    output = scheduler.schedule()
-    assert len(output.scheduled_new_reqs) == len(requests)
-    assert output.scheduled_cached_reqs.num_reqs == 0
-    assert len(output.finished_req_ids) == 0
-    # Verify all requests are scheduled.
-    for req_id, num_tokens in output.num_scheduled_tokens.items():
-        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
-
-    # Verify requests moved from waiting to running
-    assert len(scheduler.waiting) == 0
-    assert len(scheduler.running) == len(requests)
-    for i, request in enumerate(requests):
-        assert scheduler.running[i] == request
-
-
-@pytest.mark.parametrize("enable_prefix_caching", [True, False])
-def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
-    """Test scheduling behavior with concurrent partial requests.
-
-    This test verifies that: there are multiple long prefill requests in the
-    RUNNING state, and we can schedule them together.
-
-    """
-    scheduler = create_scheduler(
-        model="facebook/opt-125m",
-        max_num_batched_tokens=1024,
-        long_prefill_token_threshold=400,
-        enable_prefix_caching=enable_prefix_caching,
-        enable_chunked_prefill=True,
-    )
-    requests = create_requests(
-        num_requests=3,
-        num_tokens=800,
-    )
-    for request in requests:
-        scheduler.add_request(request)
-
-    output = scheduler.schedule()
-    assert len(output.scheduled_new_reqs) == 3
-    assert output.scheduled_cached_reqs.num_reqs == 0
-    assert len(output.finished_req_ids) == 0
-
-    # The first request is scheduled partially - 400.
-    assert output.num_scheduled_tokens[requests[0].request_id] == 400
-    # The second request is scheduled partially - 400.
-    assert output.num_scheduled_tokens[requests[1].request_id] == 400
-    # The third request is also scheduled partially - 1024 - 400 - 400 = 224.
-    assert output.num_scheduled_tokens[requests[2].request_id] == 224
-    req_to_index = {
-        request.request_id: i
-        for i, request in enumerate(requests)
-    }
-    model_runner_output = ModelRunnerOutput(
-        req_ids=[request.request_id for request in requests],
-        req_id_to_index=req_to_index,
-        sampled_token_ids=[[] for _ in range(len(requests))],
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-    scheduler.update_from_output(output, model_runner_output)
-
-    # Schedule the next step. All three requests are running.
-    # Processed the remaining prefills of the first and second requests.
-    output1 = scheduler.schedule()
-    assert len(scheduler.running) == 3
-    assert len(output1.scheduled_new_reqs) == 0
-    assert output1.scheduled_cached_reqs.num_reqs == 3
-    assert len(output1.finished_req_ids) == 0
-    assert output1.num_scheduled_tokens[requests[0].request_id] == 400
-    assert output1.num_scheduled_tokens[requests[1].request_id] == 400
-    assert output1.num_scheduled_tokens[requests[2].request_id] == 224
-
-    # Schedule the third step. All three requests are running.
-    # First and second requests are in the decode stage.
-    # All the remaining tokens in the third request are processed.
-    model_runner_output = ModelRunnerOutput(
-        req_ids=[request.request_id for request in requests],
-        req_id_to_index=req_to_index,
-        sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    scheduler.update_from_output(output1, model_runner_output)
-    output2 = scheduler.schedule()
-    assert len(scheduler.running) == 3
-    assert len(output2.scheduled_new_reqs) == 0
-    assert output2.scheduled_cached_reqs.num_reqs == 3
-    assert len(output2.finished_req_ids) == 0
-    assert output2.num_scheduled_tokens[requests[0].request_id] == 1
-    assert output2.num_scheduled_tokens[requests[1].request_id] == 1
-    assert output2.num_scheduled_tokens[
-        requests[2].request_id] == 800 - 224 - 224
-
-
-def test_stop_via_update_from_output():
-    """Test stopping behavior through update_from_output"""
-    scheduler = create_scheduler(num_speculative_tokens=1)
-
-    # Test case 1: Stop on EOS token
-    requests = create_requests(num_requests=2, max_tokens=10)
-    for req in requests:
-        req.num_computed_tokens = req.num_tokens
-        scheduler.requests[req.request_id] = req
-        scheduler.running.append(req)
-        if not vllm_version_is("0.9.2"):
-            req.status = RequestStatus.RUNNING
-
-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 1,
-                                           requests[1].request_id: 2
-                                       },
-                                       total_num_scheduled_tokens=3,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [],
-                                           requests[1].request_id: [10]
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
-
-    model_output = ModelRunnerOutput(
-        req_ids=[req.request_id for req in requests],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(requests)
-        },
-        sampled_token_ids=[[EOS_TOKEN_ID],
-                           [10,
-                            11]],  # First request hits EOS, second continues
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    scheduler.update_from_output(scheduler_output, model_output)
-
-    # Verify first request stopped, second continues
-    assert len(scheduler.running) == 1
-    assert scheduler.running[0].request_id == requests[1].request_id
-    assert requests[0].status == RequestStatus.FINISHED_STOPPED
-    assert requests[0].request_id in scheduler.finished_req_ids
-    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
-    assert list(requests[1].output_token_ids) == [10, 11]
-
-    # Test case 2: Stop on custom stop token
-    scheduler = create_scheduler(num_speculative_tokens=2)
-    requests = create_requests(num_requests=2,
-                               max_tokens=10,
-                               stop_token_ids=[42, 43])
-    for req in requests:
-        req.num_computed_tokens = req.num_tokens
-        scheduler.requests[req.request_id] = req
-        scheduler.running.append(req)
-        if not vllm_version_is("0.9.2"):
-            req.status = RequestStatus.RUNNING
-
-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 3,
-                                           requests[1].request_id: 2
-                                       },
-                                       total_num_scheduled_tokens=5,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [10, 42],
-                                           requests[1].request_id: [13]
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
-
-    model_output = ModelRunnerOutput(
-        req_ids=[req.request_id for req in requests],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(requests)
-        },
-        sampled_token_ids=[[10, 42, 12],
-                           [13, 14]],  # First request hits stop token
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    scheduler.update_from_output(scheduler_output, model_output)
-
-    # Verify first request stopped on custom token
-    assert len(scheduler.running) == 1
-    assert scheduler.running[0].request_id == requests[1].request_id
-    assert requests[0].status == RequestStatus.FINISHED_STOPPED
-    assert requests[0].stop_reason == 42
-    assert requests[0].request_id in scheduler.finished_req_ids
-    assert list(requests[0].output_token_ids) == [10, 42]
-    assert list(requests[1].output_token_ids) == [13, 14]
-
-    # Test case 3: Stop on max tokens
-    scheduler = create_scheduler(num_speculative_tokens=2)
-    requests = create_requests(num_requests=2, max_tokens=2)
-    for req in requests:
-        req.num_computed_tokens = req.num_tokens
-        scheduler.requests[req.request_id] = req
-        scheduler.running.append(req)
-        if not vllm_version_is("0.9.2"):
-            req.status = RequestStatus.RUNNING
-
-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 3,
-                                           requests[1].request_id: 1
-                                       },
-                                       total_num_scheduled_tokens=4,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [10, 11],
-                                           requests[1].request_id: []
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
-
-    model_output = ModelRunnerOutput(
-        req_ids=[req.request_id for req in requests],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(requests)
-        },
-        sampled_token_ids=[[10, 11, 12],
-                           [13]],  # First request exceeds max_tokens
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    scheduler.update_from_output(scheduler_output, model_output)
-
-    # Verify first request stopped due to length
-    assert len(scheduler.running) == 1
-    assert scheduler.running[0].request_id == requests[1].request_id
-    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
-    assert requests[0].request_id in scheduler.finished_req_ids
-    assert list(requests[0].output_token_ids) == [10, 11
-                                                  ]  # Truncated to max_tokens
-    assert list(requests[1].output_token_ids) == [13]
-
-    # Test case 4: Ignore EOS flag
-    scheduler = create_scheduler(num_speculative_tokens=2)
-    requests = create_requests(num_requests=1, max_tokens=10)
-    requests[0].sampling_params.ignore_eos = True
-    requests[0].num_computed_tokens = requests[0].num_tokens
-    scheduler.requests[requests[0].request_id] = requests[0]
-    scheduler.running.append(requests[0])
-
-    scheduler_output = SchedulerOutput(
-        scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
-        num_scheduled_tokens={requests[0].request_id: 3},
-        total_num_scheduled_tokens=3,
-        scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [EOS_TOKEN_ID, 10]
-        },
-        num_common_prefix_blocks=0,
-        finished_req_ids=set(),
-        free_encoder_input_ids=[],
-        structured_output_request_ids={},
-        grammar_bitmask=None)
-
-    model_output = ModelRunnerOutput(
-        req_ids=[requests[0].request_id],
-        req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    scheduler.update_from_output(scheduler_output, model_output)
-
-    # Verify request continues past EOS
-    assert len(scheduler.running) == 1
-    assert not requests[0].is_finished()
-    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
-
-
-@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
-    (None, None),
-    (True, 5),
-])
-def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
-                                     prompt_logprobs: Optional[int]):
-    scheduler = create_scheduler(
-        max_num_batched_tokens=1024,
-        max_num_seqs=2,
-        enable_prefix_caching=enable_prefix_caching,
-        enable_chunked_prefill=True,
-    )
-    requests = create_requests(
-        num_requests=2,
-        num_tokens=512,
-        prompt_logprobs=prompt_logprobs,
-    )
-
-    # Schedule the first request.
-    scheduler.add_request(requests[0])
-    scheduler_output0 = scheduler.schedule()
-    assert len(scheduler_output0.scheduled_new_reqs) == 1
-    assert scheduler_output0.num_scheduled_tokens[
-        requests[0].request_id] == 512
-
-    # The first request is still running, so only schedule the second request.
-    scheduler.add_request(requests[1])
-    scheduler_output1 = scheduler.schedule()
-    assert len(scheduler_output1.scheduled_new_reqs) == 1
-    assert scheduler_output1.num_scheduled_tokens[
-        requests[1].request_id] == 512
-
-    # Model output of the first request.
-    model_runner_output = ModelRunnerOutput(
-        req_ids=[requests[0].request_id],
-        req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[[0]],
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    scheduler.update_from_output(scheduler_output0, model_runner_output)
-
-    # Schedule the next step.
-    # The first request can be scheduled again while the second
-    # request is still running.
-    scheduler_output2 = scheduler.schedule()
-    assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
-
-    # Model output of the second request.
-    model_runner_output = ModelRunnerOutput(
-        req_ids=[requests[1].request_id],
-        req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[[0]],
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    scheduler.update_from_output(scheduler_output1, model_runner_output)
-
-
-# Note - these test cases mirror some of those in test_rejection_sampler.py
-@pytest.mark.parametrize(
-    "spec_tokens,output_tokens,expected",
-    [
-        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
-        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
-        ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
-         (2, 3, 3, [2, 1])),  # multiple sequences
-        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
-        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
-        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
-         (2, 6, 3, [2, 1, 0])),  # multiple mismatches
-    ])
-def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
-    """Test scheduling behavior with speculative decoding.
-
-    This test verifies that:
-    1. Speculated tokens get scheduled correctly
-    2. Spec decoding stats properly count number of draft and accepted tokens
-    """
-    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
-    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
-    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
-    req_ids = []
-    req_to_index = {}
-    for i, request in enumerate(requests):
-        scheduler.add_request(request)
-        req_ids.append(request.request_id)
-        req_to_index[request.request_id] = i
-
-    # Schedule a decode, which will also draft speculative tokens
-    output = scheduler.schedule()
-    assert len(output.scheduled_new_reqs) == len(requests)
-    assert output.total_num_scheduled_tokens == len(requests)
-    for i in range(len(requests)):
-        req_id = requests[i].request_id
-        assert output.num_scheduled_tokens[req_id] == 1
-        assert req_id not in output.scheduled_spec_decode_tokens
-
-    model_runner_output = ModelRunnerOutput(
-        req_ids=req_ids,
-        req_id_to_index=req_to_index,
-        sampled_token_ids=[[0] for _ in range(len(requests))],
-        spec_token_ids=spec_tokens,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-    engine_core_outputs = scheduler.update_from_output(output,
-                                                       model_runner_output)
-
-    for i in range(len(requests)):
-        running_req = scheduler.running[i]
-        # The prompt token
-        assert running_req.num_computed_tokens == 1
-        # The prompt token and the sampled token
-        assert running_req.num_tokens == 2
-        # The prompt token, the sampled token, and the speculated tokens
-        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
-
-    # No draft or accepted tokens counted yet
-    assert not engine_core_outputs or (
-        engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
-
-    # Schedule the speculated tokens for validation
-    output = scheduler.schedule()
-    assert len(output.scheduled_new_reqs) == 0
-    # The sampled token and speculated tokens
-    assert output.total_num_scheduled_tokens == \
-        len(requests) + sum(len(ids) for ids in spec_tokens)
-    for i in range(len(requests)):
-        req_id = requests[i].request_id
-        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
-        if spec_tokens[i]:
-            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
-                len(spec_tokens[i])
-        else:
-            assert req_id not in output.scheduled_spec_decode_tokens
-
-    model_runner_output = ModelRunnerOutput(req_ids=req_ids,
-                                            req_id_to_index=req_to_index,
-                                            sampled_token_ids=output_tokens,
-                                            spec_token_ids=None,
-                                            logprobs=None,
-                                            prompt_logprobs_dict={},
-                                            pooler_output=[])
-
-    engine_core_outputs = scheduler.update_from_output(output,
-                                                       model_runner_output)
-
-    scheduler_stats = engine_core_outputs[0].scheduler_stats \
-        if engine_core_outputs else None
-    if expected[0] == 0:
-        assert scheduler_stats.spec_decoding_stats is None  # type: ignore
-    else:
-        assert scheduler_stats.spec_decoding_stats is not None  # type: ignore
-        stats = scheduler_stats.spec_decoding_stats  # type: ignore
-        assert stats.num_drafts == expected[0]
-        assert stats.num_draft_tokens == expected[1]
-        assert stats.num_accepted_tokens == expected[2]
-        assert stats.num_accepted_tokens_per_pos == expected[3]
-
-
-def make_output(scheduler: AscendScheduler):
-    return ModelRunnerOutput(
-        req_ids=[req.request_id for req in scheduler.running],
-        req_id_to_index={
-            req.request_id: i
-            for i, req in enumerate(scheduler.running)
-        },
-        sampled_token_ids=[[1000]] * len(scheduler.running),
-        spec_token_ids=None,
-        logprobs=None,
-        prompt_logprobs_dict={},
-        pooler_output=[])
-
-
-def assert_scheduler_empty(scheduler: AscendScheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks.""" - # Scheduler Metadata. - assert len(scheduler.requests) == 0 - assert len(scheduler.waiting) == 0 - assert len(scheduler.running) == 0 - assert len(scheduler.finished_req_ids) == 0 - - # EncoderCacheManager. - assert len(scheduler.encoder_cache_manager.freed) == 0 - assert len(scheduler.encoder_cache_manager.cached) == 0 - - # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 - num_free_blocks = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - assert num_free_blocks == ( - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) - - # NOTE(rob): just the ref count on blocks will be 0. The hash - # value, etc will remain since we lazily evict for prefix cache. - for block in scheduler.kv_cache_manager.block_pool.blocks: - assert block.ref_cnt == 0 - - -def test_memory_leak(): - """Test that we do not have a memory leak.""" - - scheduler = create_scheduler(enable_prefix_caching=True) - - NUM_REQUESTS = 5 - NUM_TOKENS = 10 - MAX_TOKENS = 10 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) - - # Add each request. - for request in requests: - scheduler.add_request(request) - scheduler_output = scheduler.schedule() - model_runner_output = make_output(scheduler) - scheduler.update_from_output(scheduler_output, model_runner_output) - - # Iterate until done. - while True: - scheduler_output = scheduler.schedule() - if len(scheduler.running) == 0: - break - model_runner_output = make_output(scheduler) - scheduler.update_from_output(scheduler_output, model_runner_output) - - # Confirm no memory leak. 
-    assert_scheduler_empty(scheduler)
-
-
 def test_concurrent_partial_prefill():
     with VllmRunner(MODEL,
                     additional_config={
diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
new file mode 100644
index 0000000000..74aa8b84da
--- /dev/null
+++ b/tests/ut/core/test_scheduler.py
@@ -0,0 +1,718 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
+from vllm.multimodal.inputs import PlaceholderRange
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec)
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager
+
+from tests.ut.base import TestBase
+from vllm_ascend.core.scheduler import AscendScheduler
+from vllm_ascend.utils import vllm_version_is
+
+EOS_TOKEN_ID = 50256
+MODEL = "Qwen3-0.6B"
+# Module-level knobs consumed by TestAscendScheduler.create_scheduler();
+# individual tests override them via `global` before building a scheduler.
+ENABLE_PREFIX_CACHING = None
+PROMPT_LOGPROBS = None
+ENABLE_CHUNKED_PREFILL = False
+MAX_NUM_BATCHED_TOKENS = 10000
+LONG_PREFILL_TOKEN_THRESHOLD = 0
+NUM_SPECULATIVE_TOKENS = None
+MAX_NUM_SEQS = 16
+
+
+def create_requests(
+    num_requests: int,
+    num_tokens: int = 10,
+    mm_positions: Optional[list[PlaceholderRange]] = None,
+    max_tokens: int = 16,
+    stop_token_ids: Optional[list[int]] = None,
+):
+    prompt_logprobs = PROMPT_LOGPROBS
+    sampling_params = SamplingParams(ignore_eos=False,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=stop_token_ids,
+                                     prompt_logprobs=prompt_logprobs)
+    requests = []
+    for i in range(num_requests):
+        mm_position = None
+        mm_inputs = None
+        request = Request(
+            request_id=f"{i}",
+            prompt_token_ids=[i] * num_tokens,
+            sampling_params=sampling_params,
+            multi_modal_inputs=mm_inputs,
+            multi_modal_placeholders=mm_position,
+            multi_modal_hashes=None,
+            eos_token_id=EOS_TOKEN_ID,
+            pooling_params=None,
+        )
+        requests.append(request)
+    return requests
+
+
+def make_output(scheduler):
+    return ModelRunnerOutput(
+        req_ids=[req.request_id for req in scheduler.running],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(scheduler.running)
+        },
+        sampled_token_ids=[[1000]] * len(scheduler.running),
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[])
+
+
+class TestAscendScheduler(TestBase):
+
+    @patch("vllm.config.ModelConfig.__post_init__", MagicMock())
+    @patch("vllm.config.VllmConfig.__post_init__", MagicMock())
+    @patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
+    def create_scheduler(self, mock_compute_encoder_budget):
+        mock_compute_encoder_budget.return_value = [10, 20]
+        use_kv_connector = False
+        block_size = 16
+
+        scheduler_config = SchedulerConfig(
+            max_num_seqs=MAX_NUM_SEQS,
+            max_model_len=MAX_NUM_BATCHED_TOKENS,
+            long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
+            disable_chunked_mm_input=False,
+            enable_chunked_prefill=ENABLE_CHUNKED_PREFILL,
+            max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
+        )
+
+        scheduler_config.max_num_encoder_input_tokens = 10000
+        scheduler_config.encoder_cache_size = 10000
+        scheduler_config.chunked_prefill_enabled = False
+
+        model_config = ModelConfig(
+            model=MODEL,
+            task="auto",
+            tokenizer=MODEL,
+            tokenizer_mode="auto",
+            trust_remote_code=True,
+            dtype="float16",
+            seed=42,
+            max_model_len=MAX_NUM_BATCHED_TOKENS,
+        )
+        model_config.pooler_config = MagicMock()
+        model_config.multimodal_config = MagicMock()
+        # Cache config, optionally force APC
+        kwargs_cache: Dict[str, Any] = ({} if ENABLE_PREFIX_CACHING is None
+                                        else {
+                                            'enable_prefix_caching':
+                                            ENABLE_PREFIX_CACHING
+                                        })
+        cache_config = CacheConfig(
+            block_size=block_size,
+            gpu_memory_utilization=0.9,
+            swap_space=0,
+            cache_dtype="auto",
+            **kwargs_cache,
+        )
+
+        kv_transfer_config = KVTransferConfig(
+            kv_connector="SharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={"shared_storage_path": "local_storage"},
+        ) if use_kv_connector else None
+
+        speculative_config: Optional[SpeculativeConfig] = None
+        if NUM_SPECULATIVE_TOKENS is not None:
+            speculative_config = SpeculativeConfig(
+                model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS)
+
+        vllm_config = VllmConfig(
+            scheduler_config=scheduler_config,
+            model_config=model_config,
+            cache_config=cache_config,
+            kv_transfer_config=kv_transfer_config,
+            speculative_config=speculative_config,
+        )
+
+        kv_cache_config = KVCacheConfig(
+            num_blocks=10000,  # A large number of blocks to hold all requests
+            kv_cache_tensors=[],
+            kv_cache_groups=[
+                KVCacheGroupSpec(['layer'],
+                                 FullAttentionSpec(block_size, 1, 1,
+                                                   torch.float32, False))
+            ],
+        )
+        cache_config.num_gpu_blocks = 10000
+
+        scheduler = AscendScheduler(
+            vllm_config=vllm_config,
+            kv_cache_config=kv_cache_config,
+            log_stats=True,
+            structured_output_manager=MagicMock(spec=StructuredOutputManager),
+        )
+
+        should_advance = MagicMock()
+        should_advance.return_value = False
+        scheduler.structured_output_manager.should_advance = should_advance
+
+        return scheduler
+
+    def test_add_requests(self):
+        scheduler = self.create_scheduler()
+        requests = create_requests(num_requests=10)
+
+        for i, request in enumerate(requests):
+            scheduler.add_request(request)
+            self.assertIn(request.request_id, scheduler.requests)
+            self.assertEqual(len(scheduler.waiting), i + 1)
+
+    def test_finish_request(self):
+        scheduler = self.create_scheduler()
+        requests = create_requests(num_requests=10)
+        for request in requests:
+            scheduler.add_request(request)
+
+        for i, request in enumerate(requests):
+            scheduler.finish_requests(request.request_id,
+                                      RequestStatus.FINISHED_ABORTED)
+            self.assertNotIn(request.request_id, scheduler.requests)
+            self.assertEqual(len(scheduler.waiting), 9 - i)
+
+    def test_get_num_unfinished_requests(self):
+        scheduler = self.create_scheduler()
+        requests = create_requests(num_requests=10)
+        for request in requests:
+            scheduler.add_request(request)
+
+        for i, request in enumerate(requests):
+            scheduler.finish_requests(request.request_id,
+                                      RequestStatus.FINISHED_STOPPED)
+            self.assertEqual(scheduler.get_num_unfinished_requests(),
+                             len(requests) - i - 1)
+
+    def test_schedule(self):
+        '''Test scheduling with the default config:
+        APC off and no prompt logprobs.
+        '''
+        scheduler = self.create_scheduler()
+        scheduler.scheduler_config.chunked_prefill_enabled = False
+        requests = create_requests(num_requests=10)
+        for request in requests:
+            scheduler.add_request(request)
+
+        # Test initial scheduling
+        output = scheduler.schedule()
+        self.assertEqual(len(output.scheduled_new_reqs), len(requests))
+        self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(output.finished_req_ids), 0)
+        # Verify all requests are scheduled.
+        for req_id, num_tokens in output.num_scheduled_tokens.items():
+            self.assertEqual(num_tokens,
+                             len(requests[int(req_id)].prompt_token_ids))
+
+        # Verify requests moved from waiting to running
+        self.assertEqual(len(scheduler.waiting), 0)
+        self.assertEqual(len(scheduler.running), len(requests))
+        for i, request in enumerate(requests):
+            self.assertEqual(scheduler.running[i], request)
+
+    def test_schedule_enable_prefix_caching(self):
+        '''Test scheduling with APC enabled and prompt logprobs requested.
+        '''
+        global ENABLE_PREFIX_CACHING
+        ENABLE_PREFIX_CACHING = True
+        global PROMPT_LOGPROBS
+        PROMPT_LOGPROBS = 5
+        scheduler = self.create_scheduler()
+        scheduler.scheduler_config.chunked_prefill_enabled = False
+        requests = create_requests(num_requests=10)
+        for request in requests:
+            scheduler.add_request(request)
+
+        # Test initial scheduling
+        output = scheduler.schedule()
+        self.assertEqual(len(output.scheduled_new_reqs), len(requests))
+        self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(output.finished_req_ids), 0)
+        # Verify all requests are scheduled.
+        for req_id, num_tokens in output.num_scheduled_tokens.items():
+            self.assertEqual(num_tokens,
+                             len(requests[int(req_id)].prompt_token_ids))
+
+        # Verify requests moved from waiting to running
+        self.assertEqual(len(scheduler.waiting), 0)
+        self.assertEqual(len(scheduler.running), len(requests))
+        for i, request in enumerate(requests):
+            self.assertEqual(scheduler.running[i], request)
+
+    def test_stop_via_update_from_output(self):
+        """Test stopping behavior through update_from_output"""
+        global NUM_SPECULATIVE_TOKENS
+        NUM_SPECULATIVE_TOKENS = 1
+        scheduler = self.create_scheduler()
+
+        # Test case 1: Stop on EOS token
+        requests = create_requests(num_requests=2, max_tokens=10)
+        for req in requests:
+            req.num_computed_tokens = req.num_tokens
+            scheduler.requests[req.request_id] = req
+            scheduler.running.append(req)
+            if not vllm_version_is("0.9.2"):
+                req.status = RequestStatus.RUNNING
+
+        scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                           scheduled_cached_reqs=[],
+                                           num_scheduled_tokens={
+                                               requests[0].request_id: 1,
+                                               requests[1].request_id: 2
+                                           },
+                                           total_num_scheduled_tokens=3,
+                                           scheduled_encoder_inputs={},
+                                           scheduled_spec_decode_tokens={
+                                               requests[0].request_id: [],
+                                               requests[1].request_id: [10]
+                                           },
+                                           num_common_prefix_blocks=0,
+                                           finished_req_ids=set(),
+                                           free_encoder_input_ids=[],
+                                           structured_output_request_ids={},
+                                           grammar_bitmask=None)
+
+        model_output = ModelRunnerOutput(
+            req_ids=[req.request_id for req in requests],
+            req_id_to_index={
+                req.request_id: i
+                for i, req in enumerate(requests)
+            },
+            sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
+                               ],  # First request hits EOS, second continues
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[])
+
+        scheduler.update_from_output(scheduler_output, model_output)
+
+        # Verify first request stopped, second continues
+        self.assertEqual(len(scheduler.running), 1)
+        self.assertEqual(scheduler.running[0].request_id,
+                         requests[1].request_id)
+        self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
+        self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
+        self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID])
+        self.assertEqual(list(requests[1].output_token_ids), [10, 11])
+
+        # Test case 2: Stop on custom stop token
+        NUM_SPECULATIVE_TOKENS = 2
+        scheduler = self.create_scheduler()
+        requests = create_requests(num_requests=2,
+                                   max_tokens=10,
+                                   stop_token_ids=[42, 43])
+        for req in requests:
+            req.num_computed_tokens = req.num_tokens
+            scheduler.requests[req.request_id] = req
+            scheduler.running.append(req)
+            if not vllm_version_is("0.9.2"):
+                req.status = RequestStatus.RUNNING
+
+        scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                           scheduled_cached_reqs=[],
+                                           num_scheduled_tokens={
+                                               requests[0].request_id: 3,
+                                               requests[1].request_id: 2
+                                           },
+                                           total_num_scheduled_tokens=5,
+                                           scheduled_encoder_inputs={},
+                                           scheduled_spec_decode_tokens={
+                                               requests[0].request_id:
+                                               [10, 42],
+                                               requests[1].request_id: [13]
+                                           },
+                                           num_common_prefix_blocks=0,
+                                           finished_req_ids=set(),
+                                           free_encoder_input_ids=[],
+                                           structured_output_request_ids={},
+                                           grammar_bitmask=None)
+
+        model_output = ModelRunnerOutput(
+            req_ids=[req.request_id for req in requests],
+            req_id_to_index={
+                req.request_id: i
+                for i, req in enumerate(requests)
+            },
+            sampled_token_ids=[[10, 42, 12],
+                               [13, 14]],  # First request hits stop token
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[])
+
+        scheduler.update_from_output(scheduler_output, model_output)
+
+        # Verify first request stopped on custom token
+        self.assertEqual(len(scheduler.running), 1)
+        self.assertEqual(scheduler.running[0].request_id,
+                         requests[1].request_id)
+        self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
+        self.assertEqual(requests[0].stop_reason, 42)
+        self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
+        self.assertEqual(list(requests[0].output_token_ids), [10, 42])
+        self.assertEqual(list(requests[1].output_token_ids), [13, 14])
+
+        # Test case 3: Stop on max tokens
+        NUM_SPECULATIVE_TOKENS = 2
+        scheduler = self.create_scheduler()
+        requests = create_requests(num_requests=2, max_tokens=2)
+        for req in requests:
+            req.num_computed_tokens = req.num_tokens
+            scheduler.requests[req.request_id] = req
+            scheduler.running.append(req)
+            if not vllm_version_is("0.9.2"):
+                req.status = RequestStatus.RUNNING
+
+        scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                           scheduled_cached_reqs=[],
+                                           num_scheduled_tokens={
+                                               requests[0].request_id: 3,
+                                               requests[1].request_id: 1
+                                           },
+                                           total_num_scheduled_tokens=4,
+                                           scheduled_encoder_inputs={},
+                                           scheduled_spec_decode_tokens={
+                                               requests[0].request_id:
+                                               [10, 11],
+                                               requests[1].request_id: []
+                                           },
+                                           num_common_prefix_blocks=0,
+                                           finished_req_ids=set(),
+                                           free_encoder_input_ids=[],
+                                           structured_output_request_ids={},
+                                           grammar_bitmask=None)
+
+        model_output = ModelRunnerOutput(
+            req_ids=[req.request_id for req in requests],
+            req_id_to_index={
+                req.request_id: i
+                for i, req in enumerate(requests)
+            },
+            sampled_token_ids=[[10, 11, 12],
+                               [13]],  # First request exceeds max_tokens
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[])
+
+        scheduler.update_from_output(scheduler_output, model_output)
+
+        # Verify first request stopped due to length
+        self.assertEqual(len(scheduler.running), 1)
+        self.assertEqual(scheduler.running[0].request_id,
+                         requests[1].request_id)
+        self.assertEqual(requests[0].status,
+                         RequestStatus.FINISHED_LENGTH_CAPPED)
+        self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
+        self.assertEqual(list(requests[0].output_token_ids), [10, 11])
+        self.assertEqual(list(requests[1].output_token_ids), [13])
+
+        # Test case 4: Ignore EOS flag
+        scheduler = self.create_scheduler()
+        requests = create_requests(num_requests=1, max_tokens=10)
+        requests[0].sampling_params.ignore_eos = True
+        requests[0].num_computed_tokens = requests[0].num_tokens
+        scheduler.requests[requests[0].request_id] = requests[0]
+        scheduler.running.append(requests[0])
+
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={requests[0].request_id: 3},
+            total_num_scheduled_tokens=3,
+            scheduled_encoder_inputs={},
+            scheduled_spec_decode_tokens={
+                requests[0].request_id: [EOS_TOKEN_ID, 10]
+            },
+            num_common_prefix_blocks=0,
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None)
+
+        model_output = ModelRunnerOutput(
+            req_ids=[requests[0].request_id],
+            req_id_to_index={requests[0].request_id: 0},
+            sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[])
+
+        scheduler.update_from_output(scheduler_output, model_output)
+
+        # Verify request continues past EOS
+        self.assertEqual(len(scheduler.running), 1)
+        self.assertFalse(requests[0].is_finished())
+        self.assertEqual(list(requests[0].output_token_ids),
+                         [EOS_TOKEN_ID, 10, 11])
+
+    def test_schedule_concurrent_batches(self):
+        global MAX_NUM_BATCHED_TOKENS
+        global ENABLE_PREFIX_CACHING
+        global ENABLE_CHUNKED_PREFILL
+        global MAX_NUM_SEQS
+        global PROMPT_LOGPROBS
+        ENABLE_PREFIX_CACHING = None
+        MAX_NUM_BATCHED_TOKENS = 1024
+        MAX_NUM_SEQS = 2
+        ENABLE_CHUNKED_PREFILL = True
+        PROMPT_LOGPROBS = None
+
+        enable_prefix_caching_list = [None, True]
+        prompt_logprobs_list = [None, 5]
+
+        for i in range(len(enable_prefix_caching_list)):
+            ENABLE_PREFIX_CACHING = enable_prefix_caching_list[i]
+            PROMPT_LOGPROBS = prompt_logprobs_list[i]
+            scheduler = self.create_scheduler()
+            requests = create_requests(
+                num_requests=2,
+                num_tokens=512,
+            )
+
+            # Schedule the first request.
+            scheduler.add_request(requests[0])
+            scheduler_output0 = scheduler.schedule()
+            self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1)
+            self.assertEqual(
+                scheduler_output0.num_scheduled_tokens[requests[0].request_id],
+                512)
+
+            # The first request is still running, so only schedule the second request.
+            scheduler.add_request(requests[1])
+            scheduler_output1 = scheduler.schedule()
+            self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1)
+            self.assertEqual(
+                scheduler_output1.num_scheduled_tokens[requests[1].request_id],
+                512)
+
+            # Model output of the first request.
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[0].request_id],
+                req_id_to_index={requests[0].request_id: 0},
+                sampled_token_ids=[[0]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+
+            scheduler.update_from_output(scheduler_output0,
+                                         model_runner_output)
+
+            # Schedule the next step.
+            # The first request can be scheduled again while the second
+            # request is still running.
+            scheduler.schedule()
+            # Model output of the second request.
+            model_runner_output = ModelRunnerOutput(
+                req_ids=[requests[1].request_id],
+                req_id_to_index={requests[1].request_id: 0},
+                sampled_token_ids=[[0]],
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+
+            scheduler.update_from_output(scheduler_output1,
+                                         model_runner_output)
+
+    def test_schedule_spec_decoding_stats(self):
+        """Test scheduling behavior with speculative decoding.
+
+        This test verifies that:
+        1. Speculated tokens get scheduled correctly
+        2. Spec decoding stats properly count number of draft and accepted tokens
+        """
+        spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
+                                                   [[1, 2], [3]], [[1]], [[]],
+                                                   [[1, 2, 3], [4, 5, 6]]]
+        output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
+                                                     [[1, 2, 5], [3, 4]],
+                                                     [[1, 2]], [[5]],
+                                                     [[1, 2, 7], [4, 8]]]
+        expected_list: List[Tuple[int, int,
+                                  int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
+                                                      (1, 3, 1, [1, 0, 0]),
+                                                      (2, 3, 3, [2, 1]),
+                                                      (1, 1, 1, [1]),
+                                                      (0, 0, 0, [0]),
+                                                      (2, 6, 3, [2, 1, 0])]
+
+        global NUM_SPECULATIVE_TOKENS
+        for idx in range(len(spec_tokens_list)):
+            spec_tokens = spec_tokens_list[idx]
+            output_tokens = output_tokens_list[idx]
+            expected = expected_list[idx]
+            num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
+            NUM_SPECULATIVE_TOKENS = num_spec_tokens
+            scheduler = self.create_scheduler()
+            requests = create_requests(num_requests=len(spec_tokens),
+                                       num_tokens=1)
+            req_ids = []
+            req_to_index = {}
+            for i, request in enumerate(requests):
+                scheduler.add_request(request)
+                req_ids.append(request.request_id)
+                req_to_index[request.request_id] = i
+
+            # Schedule a decode, which will also draft speculative tokens
+            output = scheduler.schedule()
+            self.assertEqual(len(output.scheduled_new_reqs), len(requests))
+            self.assertEqual(output.total_num_scheduled_tokens, len(requests))
+            for i in range(len(requests)):
+                req_id = requests[i].request_id
+                self.assertEqual(output.num_scheduled_tokens[req_id], 1)
+                self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
+
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=[[0] for _ in range(len(requests))],
+                spec_token_ids=spec_tokens,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+
+            engine_core_outputs = scheduler.update_from_output(
+                output, model_runner_output)
+
+            for i in range(len(requests)):
+                running_req = scheduler.running[i]
+                # The prompt token
+                self.assertEqual(running_req.num_computed_tokens, 1)
+                # The prompt token and the sampled token
+                self.assertEqual(running_req.num_tokens, 2)
+                # The prompt token, the sampled token, and the speculated tokens
+                self.assertEqual(running_req.num_tokens_with_spec,
+                                 2 + len(spec_tokens[i]))
+
+            # No draft or accepted tokens counted yet
+            self.assertTrue(
+                not engine_core_outputs
+                or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats
+                    is None))
+
+            # Schedule the speculated tokens for validation
+            output = scheduler.schedule()
+            self.assertEqual(len(output.scheduled_new_reqs), 0)
+            # The sampled token and speculated tokens
+            self.assertEqual(
+                output.total_num_scheduled_tokens,
+                len(requests) + sum(len(ids) for ids in spec_tokens))
+            for i in range(len(requests)):
+                req_id = requests[i].request_id
+                self.assertEqual(output.num_scheduled_tokens[req_id],
+                                 1 + len(spec_tokens[i]))
+                if spec_tokens[i]:
+                    self.assertEqual(
+                        len(output.scheduled_spec_decode_tokens[req_id]),
+                        len(spec_tokens[i]))
+                else:
+                    self.assertNotIn(req_id,
+                                     output.scheduled_spec_decode_tokens)
+
+            model_runner_output = ModelRunnerOutput(
+                req_ids=req_ids,
+                req_id_to_index=req_to_index,
+                sampled_token_ids=output_tokens,
+                spec_token_ids=None,
+                logprobs=None,
+                prompt_logprobs_dict={},
+                pooler_output=[])
+
+            engine_core_outputs = scheduler.update_from_output(
+                output, model_runner_output)
+
+            scheduler_stats = engine_core_outputs[0].scheduler_stats \
+                if engine_core_outputs else None
+            if expected[0] == 0:
+                self.assertIsNone(scheduler_stats.spec_decoding_stats)
+            else:
+                self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
+                stats = scheduler_stats.spec_decoding_stats
+                self.assertEqual(stats.num_drafts, expected[0])
+                self.assertEqual(stats.num_draft_tokens, expected[1])
+                self.assertEqual(stats.num_accepted_tokens, expected[2])
+                self.assertEqual(stats.num_accepted_tokens_per_pos,
+                                 expected[3])
+
+    def assert_scheduler_empty(self, scheduler):
+        """Confirm the scheduler is "empty" - i.e. no leaks."""
+        # Scheduler Metadata.
+        self.assertEqual(len(scheduler.requests), 0)
+        self.assertEqual(len(scheduler.waiting), 0)
+        self.assertEqual(len(scheduler.running), 0)
+        self.assertEqual(len(scheduler.finished_req_ids), 0)
+
+        # EncoderCacheManager.
+        self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
+        self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)
+
+        # KVCache Manager.
+        self.assertEqual(
+            len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
+                req_to_blocks), 0)
+        self.assertEqual(
+            len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
+                num_cached_block), 0)
+        self.assertEqual(len(scheduler.kv_cache_manager.req_to_block_hashes),
+                         0)
+        num_free_blocks = (scheduler.kv_cache_manager.block_pool.
+                           free_block_queue.num_free_blocks)
+        self.assertEqual(
+            num_free_blocks,
+            scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
+
+        # NOTE(rob): just the ref count on blocks will be 0. The hash
+        # value, etc will remain since we lazily evict for prefix cache.
+        for block in scheduler.kv_cache_manager.block_pool.blocks:
+            self.assertEqual(block.ref_cnt, 0)
+
+    def test_memory_leak(self):
+        """Test that we do not have a memory leak."""
+        scheduler = self.create_scheduler()
+        NUM_REQUESTS = 5
+        NUM_TOKENS = 10
+        MAX_TOKENS = 10
+        requests = create_requests(num_requests=NUM_REQUESTS,
+                                   num_tokens=NUM_TOKENS,
+                                   max_tokens=MAX_TOKENS)
+
+        # Add each request.
+        for request in requests:
+            scheduler.add_request(request)
+            scheduler_output = scheduler.schedule()
+            model_runner_output = make_output(scheduler)
+            scheduler.update_from_output(scheduler_output, model_runner_output)
+
+        # Iterate until done.
+        while True:
+            scheduler_output = scheduler.schedule()
+            if len(scheduler.running) == 0:
+                break
+            model_runner_output = make_output(scheduler)
+            scheduler.update_from_output(scheduler_output, model_runner_output)
+
+        # Confirm no memory leak.
+        self.assert_scheduler_empty(scheduler)