Skip to content

Commit 0167efe

Browse files
chi2liu and chiliu authored
[Core] Optimize scheduler request removal for single completions (#21917)
Signed-off-by: chiliu <[email protected]> Signed-off-by: chiliu <[email protected]> Co-authored-by: chiliu <[email protected]>
1 parent c32e6ad commit 0167efe

File tree

2 files changed

+39
-8
lines changed

2 files changed

+39
-8
lines changed

vllm/v1/core/sched/scheduler.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
SchedulerOutput)
2626
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
2727
create_request_queue)
28-
from vllm.v1.core.sched.utils import check_stop
28+
from vllm.v1.core.sched.utils import check_stop, remove_all
2929
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
3030
EngineCoreOutputs)
3131
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -872,9 +872,7 @@ def update_from_output(
872872

873873
# Remove the stopped requests from the running and waiting queues.
874874
if stopped_running_reqs:
875-
self.running = [
876-
req for req in self.running if req not in stopped_running_reqs
877-
]
875+
self.running = remove_all(self.running, stopped_running_reqs)
878876
if stopped_preempted_reqs:
879877
# This is a rare case and unlikely to impact performance.
880878
self.waiting.remove_requests(stopped_preempted_reqs)
@@ -1000,7 +998,7 @@ def finish_requests(
1000998
else:
1001999
request_ids = set(request_ids)
10021000

1003-
running_requests_to_remove = []
1001+
running_requests_to_remove = set()
10041002
waiting_requests_to_remove = []
10051003
valid_requests = []
10061004

@@ -1013,13 +1011,13 @@ def finish_requests(
10131011

10141012
valid_requests.append(request)
10151013
if request.status == RequestStatus.RUNNING:
1016-
running_requests_to_remove.append(request)
1014+
running_requests_to_remove.add(request)
10171015
else:
10181016
waiting_requests_to_remove.append(request)
10191017

10201018
# Remove all requests from queues at once for better efficiency
1021-
for request in running_requests_to_remove:
1022-
self.running.remove(request)
1019+
if running_requests_to_remove:
1020+
self.running = remove_all(self.running, running_requests_to_remove)
10231021
if waiting_requests_to_remove:
10241022
self.waiting.remove_requests(waiting_requests_to_remove)
10251023

vllm/v1/core/sched/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,45 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import contextlib
34
from typing import Optional
45

56
import torch
67

78
from vllm.v1.request import Request, RequestStatus
89

910

11+
def remove_all(lst: list, items_to_remove: set) -> list:
12+
"""Remove all items from a list that are in the items_to_remove set.
13+
14+
This method optimizes for the common case of removing a single item,
15+
falling back to list comprehension for multiple items.
16+
17+
Args:
18+
lst: The list to remove items from
19+
items_to_remove: Set of items to remove
20+
21+
Returns:
22+
Either the modified original list (for single item removal) or
23+
a new list (for multiple item removal). Callers should use the
24+
returned value.
25+
26+
Note:
27+
For single item removal, this modifies the original list in-place
28+
and returns it. For multiple items, it creates and returns a new list.
29+
"""
30+
if not items_to_remove:
31+
return lst
32+
33+
if len(items_to_remove) == 1:
34+
# Fast path for single item removal (most common case)
35+
item = next(iter(items_to_remove))
36+
with contextlib.suppress(ValueError):
37+
lst.remove(item)
38+
return lst
39+
# For multiple items, use list comprehension
40+
return [item for item in lst if item not in items_to_remove]
41+
42+
1043
def check_stop(request: Request,
1144
max_model_len: int,
1245
pooler_output: Optional[torch.Tensor] = None) -> bool:

0 commit comments

Comments (0)