
Commit fca5a2a

Merge branch 'main' into achatter/fp8_chunk_prefill
2 parents: 1383328 + c72df80

8 files changed: +184 -125 lines

.github/workflows/pr-test-xpu.yml

Lines changed: 6 additions & 0 deletions
@@ -51,6 +51,12 @@ jobs:
           docker exec -w /root/sglang ci_sglang_xpu \
             /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py test_flash_attention.py"
 
+      - name: Run Sglang Kernel Benchmarks
+        timeout-minutes: 20
+        run: |
+          docker exec -w /root/sglang ci_sglang_xpu \
+            /bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py "
+
       - name: Run E2E Bfloat16 tests
         timeout-minutes: 20
         run: |

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -45,3 +45,4 @@
 *.pyo
 
 build
+.vscode/

Dockerfile.xpu_kernel

Lines changed: 2 additions & 16 deletions
@@ -38,20 +38,7 @@ RUN --mount=type=secret,id=github_token \
     conda activate py${PYTHON_VERSION} && \
     # . /opt/intel/oneapi/setvars.sh --force && \
     # Install Torch
-    pip install torch==2.8.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu
-
-# Install vllm from source
-RUN --mount=type=secret,id=github_token \
-    cd /root && \
-    . /miniforge3/bin/activate && \
-    conda activate py${PYTHON_VERSION} && \
-    echo "Building vllm/sglang from source ..." && \
-    git clone https://github.com/zhuyuhua-v/vllm.git && \
-    cd vllm && \
-    git checkout yuhua/deepseek && \
-    pip install setuptools_scm --root-user-action=ignore && \
-    pip install setuptools==75.6.0 packaging==24.2 --root-user-action=ignore && \
-    VLLM_TARGET_DEVICE=xpu python setup.py install
+    pip install torch==2.8.0 torchvision torchaudio pytorch-triton-xpu==3.4.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install SGlang from source
 RUN --mount=type=secret,id=github_token \
@@ -62,14 +49,13 @@ RUN --mount=type=secret,id=github_token \
     echo "cloning ${SG_LANG_BRANCH} from ${SG_LANG_REPO}" && \
     git clone --branch ${SG_LANG_BRANCH} --single-branch ${SG_LANG_REPO} && \
     cd sglang && \
-    pip install -e "python[all_xpu]" --root-user-action=ignore && \
     # Clone sgl-kernel and build sglang-kernel...
     echo "cloning ${SG_LANG_KERNEL_REPO} from ${SG_LANG_KERNEL_BRANCH}" && \
    git clone --branch ${SG_LANG_KERNEL_BRANCH} --single-branch ${SG_LANG_KERNEL_REPO} && \
     cd sgl-kernel-xpu && \
     pip install -v . &&\
     # Install required packages for sglang workloads
-    pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops --root-user-action=ignore && \
+    pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops matplotlib pandas --root-user-action=ignore && \
     conda install libsqlite=3.48.0 -y && \
     echo ". /miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; cd /root/" >> /root/.bashrc;
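
Note: besides pinning pytorch-triton-xpu==3.4.0 alongside the XPU torch wheels, matplotlib and pandas are added for the output of triton's perf_report in the new benchmark below (pandas tables, optional plots). A quick sanity check after building the image, illustrative only and not part of this commit, is to confirm the XPU wheel actually sees a device inside the container:

# Illustrative sanity check (not part of the Dockerfile): confirm the XPU-enabled
# torch wheel installed above can see an XPU device inside the container.
import torch

assert torch.xpu.is_available(), "no XPU device visible to torch"
print(torch.__version__, "XPU devices:", torch.xpu.device_count())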

benchmark/bench_flash_attn.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+from itertools import product
+
+import torch
+import triton
+from sgl_kernel.flash_attn import flash_attn_with_kvcache
+
+
+def flash_attn_baseline(
+    q,
+    k_cache,
+    v_cache,
+    causal,
+    softmax_scale,
+    cache_seqlens,
+    page_table,
+    cu_seqlens_q,
+    max_seqlen_q,
+):
+    """Baseline Flash Attention implementation"""
+    out, lse, *rest = flash_attn_with_kvcache(
+        q,
+        k_cache,
+        v_cache,
+        causal=causal,
+        softmax_scale=softmax_scale,
+        page_table=page_table,
+        cache_seqlens=cache_seqlens,
+        cu_seqlens_q=cu_seqlens_q,
+        max_seqlen_q=max_seqlen_q,
+        return_softmax_lse=True,
+    )
+    return out, lse
+
+
+# Benchmark configurations
+causal = [True, False]
+batch_size = [1, 16]
+q_seq_length_range = [1, 512, 1024]
+kv_seq_length_range = [512, 1024, 2048, 4096, 8192, 16384]
+page_size_range = [32, 64, 128]
+configs = list(
+    product(
+        causal, batch_size, q_seq_length_range, kv_seq_length_range, page_size_range
+    )
+)
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["causal", "batch_size", "q_seq_length", "kv_seq_length", "page_size"],
+        x_vals=[list(c) for c in configs],
+        line_arg="provider",
+        line_vals=["flash_attn"],
+        line_names=["Flash Attention"],
+        styles=[("blue", "-")],
+        ylabel="us",
+        plot_name="flash-attention-performance",
+        args={},
+    )
+)
+def benchmark(causal, batch_size, q_seq_length, kv_seq_length, page_size, provider):
+    dtype = torch.bfloat16
+    device = torch.device("xpu")
+
+    # Attention parameters
+    num_heads = 16
+    head_dim = 64
+
+    # Create input tensors
+    q = torch.randn(
+        (batch_size * q_seq_length, num_heads, head_dim), device=device, dtype=dtype
+    )
+    num_pages = (batch_size * kv_seq_length + page_size - 1) // page_size
+    k_cache = torch.randn(
+        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
+    )
+    v_cache = torch.randn(
+        (num_pages, page_size, num_heads, head_dim), device=device, dtype=dtype
+    )
+    cache_seqlens = (
+        torch.ones(batch_size, device=device, dtype=torch.int32) * kv_seq_length
+    )
+    page_table = (
+        torch.randperm(num_pages, device=device, dtype=torch.int32)
+        .reshape(batch_size, -1)
+        .contiguous()
+    )
+    cu_seqlens_q = torch.arange(
+        0,
+        (batch_size + 1) * q_seq_length,
+        step=q_seq_length,
+        device=device,
+        dtype=torch.int32,
+    )
+    max_seqlen_q = q_seq_length
+
+    softmax_scale = 1.0 / (head_dim**0.5)
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "flash_attn":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: flash_attn_baseline(
+                q.clone(),
+                k_cache.clone(),
+                v_cache.clone(),
+                causal=causal,
+                softmax_scale=softmax_scale,
+                cache_seqlens=cache_seqlens,
+                page_table=page_table,
+                cu_seqlens_q=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+            ),
+            quantiles=quantiles,
+        )
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    benchmark.run(print_data=True)
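
Note: running this file directly, as the new CI step does, prints a table of latencies in microseconds (the 0.5/0.2/0.8 quantiles from do_bench) for every (causal, batch_size, q_seq_length, kv_seq_length, page_size) combination in configs. The K/V caches are paged: each sequence's tokens live in the physical pages its row of page_table points to. The sketch below is illustrative only (gather_keys_for_sequence is a hypothetical helper, not part of the commit) and shows how a logical token position resolves to a (page, slot) location under that layout:

# Illustrative sketch, not part of this commit: map logical token positions of one
# sequence to physical (page, slot) locations in the paged k_cache built above.
import torch

def gather_keys_for_sequence(k_cache, page_table, batch_idx, seq_len, page_size):
    """Hypothetical helper: return the first seq_len key vectors of one sequence."""
    positions = torch.arange(seq_len, device=page_table.device)
    pages = page_table[batch_idx, positions // page_size].long()  # physical page per token
    slots = positions % page_size                                 # slot within each page
    return k_cache[pages, slots]                                  # (seq_len, num_heads, head_dim)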

include/sgl_flash_kernel_ops.h

Lines changed: 5 additions & 6 deletions
@@ -48,12 +48,11 @@ std::vector<at::Tensor> mha_fwd(
                              // h_k, d) if there is page_table.
     const at::Tensor& v,     // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                              // page_size, h_k, dv) if there is page_table.
-    std::optional<const at::Tensor>& q_v_,           // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<const at::Tensor>& cu_seqlens_q_,  // b+1
-    std::optional<const at::Tensor>& cu_seqlens_k_,  // b+1
-    std::optional<int> max_seqlen_q_,
-    std::optional<int> max_seqlen_k_,
-    std::optional<const at::Tensor>& page_table_,    // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    const at::Tensor& cu_seqlens_q,         // b+1
+    const at::Tensor& cu_seqlens_k,         // b+1
+    int max_seqlen_q,
+    const at::Tensor& page_table,
     std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
     std::optional<const at::Tensor>& leftpad_k_,     // b
     std::optional<const at::Tensor>& rotary_cos_,    // seqlen_ro x (rotary_dim / 2)

python/sgl_kernel/flash_attn.py

Lines changed: 0 additions & 2 deletions
@@ -179,7 +179,6 @@ def flash_attn_with_kvcache(
         max_seqlen_q = q.size(1)
         q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
     if cache_seqlens is not None:
-        max_seqlen_k = cache_seqlens.max().item()
         assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
         cu_seqlens_k = torch.concat(
             (
@@ -196,7 +195,6 @@
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
-        max_seqlen_k,
         page_table,
         cache_batch_idx,
         cache_leftpad,
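
Note: this mirrors the header change above. With max_seqlen_k_ removed from mha_fwd, the wrapper stops computing max_seqlen_k and no longer passes it to the kernel call. The cu_seqlens_k it still derives from cache_seqlens (the torch.concat whose body is truncated in this hunk) is a b + 1 element prefix sum, matching the // b+1 comment in the header. A hedged sketch of that construction with example lengths, not the wrapper's verbatim code:

# Hedged sketch, not the wrapper's verbatim code: build cu_seqlens_k (b + 1 entries)
# from per-sequence KV cache lengths via an exclusive prefix sum.
import torch

cache_seqlens = torch.tensor([512, 384], dtype=torch.int32)  # example per-sequence lengths
cu_seqlens_k = torch.concat(
    (
        torch.zeros(1, dtype=torch.int32),
        torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32),
    )
)
# cu_seqlens_k -> tensor([  0, 512, 896], dtype=torch.int32)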
