
Commit 15dcd15

bigPYJ1151 authored and Zhuul committed
[CPU] update torch 2.8 and fix missing fields in TorchSDPAMetadata (vllm-project#25652)
Signed-off-by: jiang1.li <[email protected]>
1 parent e7e4bfb commit 15dcd15

File tree: 7 files changed, +59 −53 lines changed

.buildkite/scripts/hardware_ci/run-cpu-test.sh

Lines changed: 2 additions & 5 deletions
@@ -58,11 +58,8 @@ function cpu_tests() {
   # pytest -x -v -s tests/kernels/attention/test_cache.py -m cpu_model
   # pytest -x -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model

-  # Note: disable Bart until supports V1
-  pytest -x -v -s tests/models/language/generation -m cpu_model \
-    --ignore=tests/models/language/generation/test_bart.py
-  VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model \
-    --ignore=tests/models/language/generation/test_bart.py
+  pytest -x -v -s tests/models/language/generation -m cpu_model
+  VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model

   pytest -x -v -s tests/models/language/pooling -m cpu_model
   pytest -x -v -s tests/models/multimodal/generation \

docker/Dockerfile.cpu

Lines changed: 0 additions & 3 deletions
@@ -114,9 +114,6 @@ WORKDIR /workspace/vllm
 RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
     cp requirements/test.in requirements/cpu-test.in && \
     sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
-    sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
-    sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
-    sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
     uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu

 RUN --mount=type=cache,target=/root/.cache/uv \

requirements/cpu-build.txt

Lines changed: 1 addition & 3 deletions
@@ -1,12 +1,10 @@
-# Temporarily used for x86 CPU backend to avoid performance regression of torch>2.6.0+cpu,
-# see https://github.com/pytorch/pytorch/pull/151218
 cmake>=3.26.1
 ninja
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.6.0+cpu; platform_machine == "x86_64"  # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
+torch==2.8.0+cpu; platform_machine == "x86_64"
 torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
 wheel
 jinja2>=3.1.6

requirements/cpu.txt

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9' and platform_machine != "s390x"
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.6.0+cpu; platform_machine == "x86_64"  # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
+torch==2.8.0+cpu; platform_machine == "x86_64"
 torch==2.8.0; platform_system == "Darwin"
 torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64"

@@ -23,7 +23,7 @@ datasets # for benchmark scripts

 # Intel Extension for PyTorch, only for x86_64 CPUs
 intel-openmp==2024.2.1; platform_machine == "x86_64"
-intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64"  # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
+intel_extension_for_pytorch==2.8.0; platform_machine == "x86_64"
 triton==3.2.0; platform_machine == "x86_64"  # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.

 # Use this to gather CPU info and optimize based on ARM Neoverse cores
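
With torch and intel_extension_for_pytorch now both pinned to the 2.8 line for x86_64, a quick environment check can confirm an installed CPU stack matches the updated requirements. This is a minimal sketch, not part of the patch, and assumes both packages are already installed:

import platform

import torch

# The pins above target the 2.8 release line; torch CPU wheels report a
# version string such as "2.8.0+cpu".
assert torch.__version__.startswith("2.8"), torch.__version__

if platform.machine() == "x86_64":
    # intel_extension_for_pytorch is only required on x86_64 CPUs.
    import intel_extension_for_pytorch as ipex
    assert ipex.__version__.startswith("2.8"), ipex.__version__

print("torch:", torch.__version__)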

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 13 additions & 1 deletion
@@ -85,6 +85,19 @@ def use_cascade_attention(*args, **kwargs) -> bool:

 @dataclass
 class TorchSDPAMetadata(AttentionMetadata):
+    """Attention metadata for prefill and decode batched together."""
+    # Total number of prefill requests.
+    num_prefills: int
+    # Number of prefill tokens.
+    num_prefill_tokens: int
+    # Number of decode tokens. Note that it is equivalent to the number of
+    # decode requests.
+    num_decode_tokens: int
+    # (num_tokens,). The indices of the token slots that input tokens will be
+    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
+    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
+    # in block 0, and 1st slot in block 1, respectively.
+    slot_mapping: torch.Tensor
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
@@ -420,7 +433,6 @@ def build(self,
                 num_prompt_req],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
-            enable_kv_scales_calculation=False,
         )

         return attn_metadata
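
The slot_mapping field documented above encodes each token's KV-cache destination as a flat slot index, i.e. slot = block_index * block_size + block_offset. A small illustrative sketch, not part of the patch, showing that arithmetic on the docstring's own example:

import torch

def decode_slot_mapping(slot_mapping: torch.Tensor, block_size: int):
    # Split each flat slot index into (block index, offset within the block).
    block_idx = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offset = slot_mapping % block_size
    return block_idx, block_offset

# Example from the docstring: block size 16 and slot_mapping [35, 2, 17].
slots = torch.tensor([35, 2, 17])
print(decode_slot_mapping(slots, 16))
# -> (tensor([2, 0, 1]), tensor([3, 2, 1]))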

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 41 additions & 0 deletions
@@ -68,6 +68,8 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
                     "native implementation of top-p & top-k sampling. For the "
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
+        elif current_platform.is_cpu():
+            self.forward = self.forward_cpu
         else:
             self.forward = self.forward_native

@@ -119,6 +121,45 @@ def forward_cuda(
         # because of slicing operation in logits_processor.
         return flashinfer_sample(logits.contiguous(), k, p, generators), None

+    def forward_cpu(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        PyTorch-native implementation of top-k and top-p sampling for CPU.
+
+        The logits tensor may be updated in-place.
+        """
+        logits = self.apply_top_k_top_p(logits, k, p)
+        logits_to_return = None
+        if self.logprobs_mode == "processed_logits":
+            logits_to_return = logits
+        elif self.logprobs_mode == "processed_logprobs":
+            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
+
+        # Note: this is a workaround for
+        # https://github.com/pytorch/pytorch/pull/151218
+        @torch.compile(dynamic=True)
+        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
+            q = torch.empty_like(probs)
+            q.exponential_()
+            return probs.div(q).argmax(dim=-1).view(-1)
+
+        if len(generators) != logits.shape[0]:
+            return compiled_random_sample(logits), logits_to_return
+        else:
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
+            q = torch.empty_like(probs)
+            q.exponential_()
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+
+            return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
+

 def apply_top_k_top_p(
     logits: torch.Tensor,
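
For context on the sampling trick in forward_cpu above: dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the per-row argmax is the standard exponential-race (Gumbel-max style) way to draw one token per row from the categorical distribution, and wrapping it in torch.compile works around the linked PyTorch CPU issue. A minimal standalone sketch, not part of the patch, that checks the behaviour empirically:

import torch

def exponential_race_sample(logits: torch.Tensor) -> torch.Tensor:
    # argmax_i of p_i / q_i with q_i ~ Exponential(1) is a draw from
    # Categorical(p); this mirrors compiled_random_sample above.
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1)

# Empirical check: sample frequencies should approach the softmax probabilities.
logits = torch.tensor([[2.0, 1.0, 0.0]])
counts = torch.zeros(3)
for _ in range(20_000):
    counts[exponential_race_sample(logits)] += 1
print(counts / counts.sum())   # roughly [0.665, 0.245, 0.090]
print(logits.softmax(dim=-1))  # reference distribution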

vllm/v1/worker/cpu_worker.py

Lines changed: 0 additions & 39 deletions
@@ -8,18 +8,13 @@

 from vllm import envs
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
-from vllm.sequence import IntermediateTensors
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
-from vllm.v1.worker.utils import is_residual_scattered_for_sp

 logger = init_logger(__name__)

@@ -102,40 +97,6 @@ def compile_or_warm_up_model(self) -> None:
         set_random_seed(self.model_config.seed)
         self.model_runner.warming_up_model()

-    @torch.inference_mode()
-    def execute_model(
-        self,
-        scheduler_output: "SchedulerOutput",
-    ) -> Optional[ModelRunnerOutput]:
-        intermediate_tensors = None
-        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_input_tokens = self.model_runner._get_num_input_tokens(
-            num_scheduled_tokens)
-        all_gather_tensors = {
-            "residual":
-            not is_residual_scattered_for_sp(self.vllm_config,
-                                             num_input_tokens)
-        }
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = IntermediateTensors(
-                get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group(),
-                    all_gather_tensors=all_gather_tensors))
-
-        output = self.model_runner.execute_model(scheduler_output,
-                                                 intermediate_tensors)
-
-        if not get_pp_group().is_last_rank:
-            assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(
-                output.tensors,
-                all_gather_group=get_tp_group(),
-                all_gather_tensors=all_gather_tensors)
-            return None
-
-        assert isinstance(output, ModelRunnerOutput)
-        return output if self.is_driver_worker else None
-
     def _get_autobind_cpu_ids(
         self, cpu_selector: Callable[[list[LogicalCPUInfo]],
                                      list[LogicalCPUInfo]]
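
With the execute_model override deleted, CPUWorker now resolves execute_model through the Worker base class imported from vllm/v1/worker/gpu_worker.py, which is expected to cover the pipeline-parallel receive/send flow that the removed code performed. A trivial stand-in sketch of that inheritance-based dispatch; the class bodies below are placeholders, not the real implementations:

class Worker:
    # Stand-in for vllm.v1.worker.gpu_worker.Worker.
    def execute_model(self, scheduler_output):
        return f"base execute_model handled {scheduler_output!r}"

class CPUWorker(Worker):
    # Stand-in for vllm.v1.worker.cpu_worker.CPUWorker: with no override,
    # calls now resolve to the inherited Worker.execute_model.
    pass

print(CPUWorker().execute_model("scheduler-output-0"))
# -> base execute_model handled 'scheduler-output-0'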
