Commit 8d0a01a

[v1][sampler] Inplace logprobs comparison to get the token rank (#21283)
Signed-off-by: Lu Fang <[email protected]>
1 parent 0ec82ed commit 8d0a01a

2 files changed (+26, -1 lines)

vllm/v1/sample/ops/logprobs.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Some utilities for logprobs, including logits."""
+
+import torch
+
+
+@torch.compile(dynamic=True)
+def batched_count_greater_than(x: torch.Tensor,
+                               values: torch.Tensor) -> torch.Tensor:
+    """
+    Counts elements in each row of x that are greater than or equal to
+    the corresponding value in values. Use torch.compile to generate an
+    optimized kernel for this function; otherwise, it will create
+    additional copies of the input tensors and cause memory issues.
+
+    Args:
+        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
+        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
+
+    Returns:
+        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
+    """
+    return (x >= values).sum(-1)
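
As a quick sketch of what the new helper computes (illustrative values, not part of the commit): because the comparison is >=, the count is the 1-indexed rank of the sampled token, with rank 1 meaning it had the highest logprob in its row (ties share the same count).

import torch

from vllm.v1.sample.ops.logprobs import batched_count_greater_than

logprobs = torch.tensor([[-0.5, -1.2, -0.1, -2.0],
                         [-3.0, -0.2, -0.9, -1.1]])
# Logprob of the sampled token per row, shape (batch_size, 1).
token_logprobs = torch.tensor([[-1.2], [-0.2]])

ranks = batched_count_greater_than(logprobs, token_logprobs)
print(ranks)  # tensor([3, 1]); row 0's token ranks 3rd, row 1's ranks 1st.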

vllm/v1/sample/sampler.py

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words
+from vllm.v1.sample.ops.logprobs import batched_count_greater_than
 from vllm.v1.sample.ops.penalties import apply_all_penalties
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

@@ -174,7 +175,7 @@ def gather_logprobs(
         token_logprobs = logprobs.gather(-1, token_ids)

         # Compute the ranks of the actual token.
-        token_ranks = (logprobs >= token_logprobs).sum(-1)
+        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

         # Concatenate together with the topk.
         indices = torch.cat((token_ids, topk_indices), dim=1)
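
Why the change helps (a sketch based on the helper's docstring, not text from the commit): in eager mode the broadcast comparison materializes a (batch_size, vocab_size) boolean intermediate before the reduction, whereas the @torch.compile'd helper can compile the comparison and sum together and avoid that allocation. The two paths agree on the result:

import torch

batch_size, vocab_size = 8, 32_000  # toy sizes for illustration
logprobs = torch.randn(batch_size, vocab_size)
token_logprobs = logprobs[:, :1].clone()  # pretend token 0 was sampled

# Old eager path: allocates a (batch_size, vocab_size) bool intermediate.
eager_ranks = (logprobs >= token_logprobs).sum(-1)

# New compiled path: comparison and reduction compiled as one function.
fused_count = torch.compile(lambda x, v: (x >= v).sum(-1), dynamic=True)
assert torch.equal(eager_ranks, fused_count(logprobs, token_logprobs))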
