File tree Expand file tree Collapse file tree 2 files changed +26
-1
lines changed Expand file tree Collapse file tree 2 files changed +26
-1
lines changed Original file line number Diff line number Diff line change
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """Some utilities for logprobs, including logits."""
4
+
5
+ import torch
6
+
7
+
8
@torch.compile(dynamic=True)
def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """Count, per row, the elements of ``x`` that are >= the row's value.

    NOTE: despite the function name, the comparison is inclusive (``>=``),
    not strict. Elements equal to the reference value are counted, so when
    ``values[i]`` is itself an element of ``x[i]`` the result is that
    element's 1-based rank within the row (ties count toward the rank).

    The function is wrapped in ``torch.compile`` to generate an optimized
    kernel; per the original author's note, running the same expression
    eagerly creates additional copies of the input tensors and can cause
    memory issues.

    Args:
        x: A 2D tensor of shape (batch_size, n_elements).
        values: A 2D tensor of shape (batch_size, 1) holding one reference
            value per row; it is broadcast across the last dimension of
            ``x``.

    Returns:
        A 1D integer tensor of shape (batch_size,) where entry ``i`` is the
        number of elements in ``x[i]`` that are greater than or equal to
        ``values[i]``.
    """
    # Boolean mask reduced over the last dimension; `>=` is intentional.
    return (x >= values).sum(-1)
Original file line number Diff line number Diff line change 9
9
from vllm .v1 .outputs import LogprobsTensors , SamplerOutput
10
10
from vllm .v1 .sample .metadata import SamplingMetadata
11
11
from vllm .v1 .sample .ops .bad_words import apply_bad_words
12
+ from vllm .v1 .sample .ops .logprobs import batched_count_greater_than
12
13
from vllm .v1 .sample .ops .penalties import apply_all_penalties
13
14
from vllm .v1 .sample .ops .topk_topp_sampler import TopKTopPSampler
14
15
@@ -174,7 +175,7 @@ def gather_logprobs(
174
175
token_logprobs = logprobs .gather (- 1 , token_ids )
175
176
176
177
# Compute the ranks of the actual token.
177
- token_ranks = (logprobs >= token_logprobs ). sum ( - 1 )
178
+ token_ranks = batched_count_greater_than (logprobs , token_logprobs )
178
179
179
180
# Concatenate together with the topk.
180
181
indices = torch .cat ((token_ids , topk_indices ), dim = 1 )
You can’t perform that action at this time.
0 commit comments