
Commit 7b210ff

Merge branch 'main' into dev/pralay-das/cutlass_mla_get_workspace_size
1 parent 3f8ced2 commit 7b210ff

File tree: 9 files changed, +1050 -31 lines

.github/workflows/pr-test-xpu.yml
Lines changed: 6 additions & 0 deletions

@@ -51,6 +51,12 @@ jobs:
           docker exec -w /root/sglang ci_sglang_xpu \
             /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py test_flash_attention.py"
 
+      - name: Run Sglang Kernel Benchmarks
+        timeout-minutes: 20
+        run: |
+          docker exec -w /root/sglang ci_sglang_xpu \
+            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py "
+
       - name: Run E2E Bfloat16 tests
         timeout-minutes: 20
         run: |

.gitignore
Lines changed: 1 addition & 0 deletions

@@ -45,3 +45,4 @@
 *.pyo
 
 build
+.vscode/

Dockerfile.xpu_kernel
Lines changed: 3 additions & 17 deletions

@@ -4,7 +4,7 @@
 
 
 # Set default Ubuntu version to 24.04
-FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04
+FROM intel/deep-learning-essentials:2025.2.2-0-devel-ubuntu24.04
 
 ENV DEBIAN_FRONTEND=noninteractive
 
@@ -38,20 +38,7 @@ RUN --mount=type=secret,id=github_token \
     conda activate py${PYTHON_VERSION} && \
     # . /opt/intel/oneapi/setvars.sh --force && \
     # Install Torch
-    pip install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu
-
-# Install vllm from source
-RUN --mount=type=secret,id=github_token \
-    cd /root && \
-    . /miniforge3/bin/activate && \
-    conda activate py${PYTHON_VERSION} && \
-    echo "Building vllm/sglang from source ..." && \
-    git clone https://github.com/zhuyuhua-v/vllm.git && \
-    cd vllm && \
-    git checkout yuhua/deepseek && \
-    pip install setuptools_scm --root-user-action=ignore && \
-    pip install setuptools==75.6.0 packaging==24.2 --root-user-action=ignore && \
-    VLLM_TARGET_DEVICE=xpu python setup.py install
+    pip install torch==2.9.0 torchvision torchaudio pytorch-triton-xpu==3.5.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install SGlang from source
 RUN --mount=type=secret,id=github_token \
@@ -62,14 +49,13 @@ RUN --mount=type=secret,id=github_token \
     echo "cloning ${SG_LANG_BRANCH} from ${SG_LANG_REPO}" && \
     git clone --branch ${SG_LANG_BRANCH} --single-branch ${SG_LANG_REPO} && \
     cd sglang && \
-    pip install -e "python[all_xpu]" --root-user-action=ignore && \
     # Clone sgl-kernel and build sglang-kernel...
     echo "cloning ${SG_LANG_KERNEL_REPO} from ${SG_LANG_KERNEL_BRANCH}" && \
     git clone --branch ${SG_LANG_KERNEL_BRANCH} --single-branch ${SG_LANG_KERNEL_REPO} && \
     cd sgl-kernel-xpu && \
     pip install -v . &&\
     # Install required packages for sglang workloads
-    pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops --root-user-action=ignore && \
+    pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops matplotlib pandas --root-user-action=ignore && \
    conda install libsqlite=3.48.0 -y && \
    echo ". /miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; cd /root/" >> /root/.bashrc;
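The base image moves to the 2025.2 toolchain, the PyTorch stack to the 2.9.0 XPU wheels with pytorch-triton-xpu 3.5.0, and matplotlib/pandas are added for the new benchmark's reporting. A minimal sanity check of the rebuilt environment could look like the sketch below; it is not part of this commit and assumes the container exposes at least one Intel XPU device.

```python
# Minimal environment check (illustrative, not part of the commit).
import torch
import triton

print("torch:", torch.__version__)        # expected to report a 2.9.0+xpu build
print("triton:", triton.__version__)      # from the pytorch-triton-xpu wheel
print("xpu available:", torch.xpu.is_available())

if torch.xpu.is_available():
    x = torch.randn(8, 8, device="xpu", dtype=torch.bfloat16)
    print((x @ x.T).shape)                # trivial matmul to confirm the device works
```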

benchmark/bench_flash_attn.py
Lines changed: 161 additions & 0 deletions (new file)

@@ -0,0 +1,161 @@
from itertools import product

import torch
import triton
from sgl_kernel.flash_attn import flash_attn_with_kvcache


def flash_attn_baseline(
    q,
    k_cache,
    v_cache,
    causal,
    window_size,
    softmax_scale,
    softmax_sink,
    cache_seqlens,
    page_table,
    cu_seqlens_q,
    max_seqlen_q,
):
    """Baseline Flash Attention implementation"""
    out, lse, *rest = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        causal=causal,
        softmax_sink=softmax_sink,
        window_size=window_size,
        softmax_scale=softmax_scale,
        page_table=page_table,
        cache_seqlens=cache_seqlens,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        return_softmax_lse=True,
    )
    return out, lse


# Benchmark configurations
causal = [True, False]
local = [True, False]
use_softmax_sink = [True, False]
batch_size = [1, 16]
q_seq_length_range = [1, 512, 1024]
kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
page_size_range = [32, 64, 128]
configs = list(
    filter(
        lambda cfg: not (cfg[0] and cfg[1]),
        product(
            causal,
            local,
            use_softmax_sink,
            batch_size,
            q_seq_length_range,
            kv_seq_length_range,
            page_size_range,
        ),
    )
)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=[
            "causal",
            "local",
            "use_softmax_sink",
            "batch_size",
            "q_seq_length",
            "kv_seq_length",
            "page_size",
        ],
        x_vals=[list(c) for c in configs],
        line_arg="provider",
        line_vals=["flash_attn"],
        line_names=["Flash Attention"],
        styles=[("blue", "-")],
        ylabel="us",
        plot_name="flash-attention-performance",
        args={},
    )
)
def benchmark(
    causal,
    local,
    use_softmax_sink,
    batch_size,
    q_seq_length,
    kv_seq_length,
    page_size,
    provider,
):
    dtype = torch.bfloat16
    device = torch.device("xpu")

    # Attention parameters
    num_heads = 16
    head_dim = 64

    # Create input tensors
    q = torch.randn(
        (batch_size * q_seq_length, num_heads, head_dim), device=device, dtype=dtype
    )
    num_pages = (batch_size * kv_seq_length + page_size - 1) // page_size
    k_cache = torch.randn(
        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
    )
    v_cache = torch.randn(
        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
    )
    cache_seqlens = (
        torch.ones(batch_size, device=device, dtype=torch.int32) * kv_seq_length
    )
    page_table = (
        torch.randperm(num_pages, device=device, dtype=torch.int32)
        .reshape(batch_size, -1)
        .contiguous()
    )
    cu_seqlens_q = torch.arange(
        0,
        (batch_size + 1) * q_seq_length,
        step=q_seq_length,
        device=device,
        dtype=torch.int32,
    )
    max_seqlen_q = q_seq_length
    window_size = (-1, -1) if not local else torch.randint(0, kv_seq_length, (2,))

    softmax_sink = (
        torch.randn(num_heads, device=device, dtype=dtype) if use_softmax_sink else None
    )

    softmax_scale = 1.0 / (head_dim**0.5)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "flash_attn":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: flash_attn_baseline(
                q.clone(),
                k_cache.clone(),
                v_cache.clone(),
                causal=causal,
                window_size=window_size,
                softmax_scale=softmax_scale,
                softmax_sink=softmax_sink,
                cache_seqlens=cache_seqlens,
                page_table=page_table,
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
            ),
            quantiles=quantiles,
        )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
    print("Benchmark finished!")

cmake/BuildFlags.cmake
Lines changed: 7 additions & 1 deletion

@@ -73,7 +73,13 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
     set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -no-ftz)
     set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -fno-sycl-instrument-device-code)
     set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -Xspirv-translator)
-    set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier) #,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate)
+
+    # SYCL compiler in basekit after 2025.2 needs more spirv arguments.
+    if(SYCL_COMPILER_VERSION GREATER_EQUAL 20250806)
+      set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate)
+    else()
+      set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -spirv-ext=+SPV_INTEL_split_barrier)
+    endif()
 
     if(CMAKE_BUILD_TYPE MATCHES Debug)
       set(SYCL_KERNEL_OPTIONS ${SYCL_KERNEL_OPTIONS} -g -O0 -Rno-debug-disables-optimization)

include/sgl_flash_kernel_ops.h
Lines changed: 5 additions & 6 deletions

@@ -48,12 +48,11 @@ std::vector<at::Tensor> mha_fwd(
                          // h_k, d) if there is page_table.
     const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                          // page_size, h_k, dv) if there is page_table.
-    std::optional<const at::Tensor>& q_v_,          // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<const at::Tensor>& cu_seqlens_q_, // b+1
-    std::optional<const at::Tensor>& cu_seqlens_k_, // b+1
-    std::optional<int> max_seqlen_q_,
-    std::optional<int> max_seqlen_k_,
-    std::optional<const at::Tensor>& page_table_,   // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& q_v_,          // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    const at::Tensor& cu_seqlens_q,                 // b+1
+    const at::Tensor& cu_seqlens_k,                 // b+1
+    int max_seqlen_q,
+    const at::Tensor& page_table,
     std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
     std::optional<const at::Tensor>& leftpad_k_,    // b
     std::optional<const at::Tensor>& rotary_cos_,   // seqlen_ro x (rotary_dim / 2)
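With this change, `cu_seqlens_q`, `cu_seqlens_k`, `max_seqlen_q`, and `page_table` become required arguments of `mha_fwd` rather than optionals, and `max_seqlen_k` is dropped, so callers must always materialize them. A minimal Python call sketch mirroring the shapes used in benchmark/bench_flash_attn.py above; the concrete sizes are illustrative and not part of the commit.

```python
# Illustrative call through the Python wrapper (shapes follow the benchmark file).
import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

batch, q_len, kv_len, page_size, heads, dim = 2, 4, 256, 64, 16, 64
device, dtype = torch.device("xpu"), torch.bfloat16

q = torch.randn(batch * q_len, heads, dim, device=device, dtype=dtype)
num_pages = (batch * kv_len + page_size - 1) // page_size
k_cache = torch.randn(num_pages, page_size, heads, dim, device=device, dtype=dtype)
v_cache = torch.randn(num_pages, page_size, heads, dim, device=device, dtype=dtype)

cache_seqlens = torch.full((batch,), kv_len, device=device, dtype=torch.int32)
page_table = torch.arange(num_pages, device=device, dtype=torch.int32).reshape(batch, -1)
cu_seqlens_q = torch.arange(0, (batch + 1) * q_len, q_len, device=device, dtype=torch.int32)

# The wrapper derives cu_seqlens_k from cache_seqlens and forwards the now-required
# cu_seqlens_q / max_seqlen_q / page_table into mha_fwd.
out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    causal=True,
    window_size=(-1, -1),
    page_table=page_table,
    cache_seqlens=cache_seqlens,
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=q_len,
)
```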

python/sgl_kernel/flash_attn.py
Lines changed: 0 additions & 2 deletions

@@ -179,7 +179,6 @@ def flash_attn_with_kvcache(
         max_seqlen_q = q.size(1)
         q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
     if cache_seqlens is not None:
-        max_seqlen_k = cache_seqlens.max().item()
         assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
         cu_seqlens_k = torch.concat(
             (
@@ -196,7 +195,6 @@
             cu_seqlens_q,
             cu_seqlens_k,
             max_seqlen_q,
-            max_seqlen_k,
             page_table,
             cache_batch_idx,
             cache_leftpad,
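`max_seqlen_k` is no longer computed from `cache_seqlens` nor forwarded to the kernel, matching the `mha_fwd` signature change above; `cu_seqlens_k` is still built from `cache_seqlens`. The hunk cuts off mid-expression, so the sketch below only shows the usual way such cumulative KV offsets are formed, not the file's verbatim code.

```python
# Hedged sketch of deriving cumulative KV offsets from per-sequence cache lengths.
import torch

cache_seqlens = torch.tensor([100, 250, 70], dtype=torch.int32)
cu_seqlens_k = torch.concat(
    (
        torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32),
    )
)
print(cu_seqlens_k)  # tensor([  0, 100, 350, 420], dtype=torch.int32)
```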
