Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,11 @@ if(BUILD_SYCL_TLA_KERNELS)
# NOTE(review): this span was pasted from a diff view without +/- markers, so
# added and removed lines are interleaved as plain text — reconcile against the
# actual PR before building.
add_subdirectory(csrc/xpu/attn/xe_2)
add_subdirectory(csrc/xpu/gdn_attn/xe_2)
list(APPEND GROUPED_GEMM_LIB_NAME "grouped_gemm_xe_2")
# NOTE(review): the next line appears to be the REMOVED side of the diff (the
# file header reports "5 additions & 1 deletion"); the monolithic
# attn_kernels_xe_2 library is superseded by the split prefill/decode
# libraries registered just below — confirm and drop one side.
list(APPEND ATTN_KERNEL_LIB_NAME "attn_kernels_xe_2")
# Added side: register the split prefill and decode kernel libraries built in
# csrc/xpu/attn/xe_2.
list(APPEND ATTN_KERNEL_LIB_NAME "attn_prefill_kernels_xe_2")
list(APPEND ATTN_KERNEL_LIB_NAME "attn_decode_kernels_xe_2")
# ATTN_HAS_FP8Q_LIB is published with PARENT_SCOPE by
# csrc/xpu/attn/xe_2/CMakeLists.txt when fp8-query prefill sources exist.
if(ATTN_HAS_FP8Q_LIB)
list(APPEND ATTN_KERNEL_LIB_NAME "attn_prefill_kernels_xe_2_fp8q")
endif()
list(APPEND GDN_ATTN_LIB_NAME "gdn_attn_kernels_xe_2")
list(APPEND SYCL_TLA_COMPILE_OPTIONS -DVLLM_XPU_ENABLE_XE2)
endif()
Expand Down
10 changes: 7 additions & 3 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ function(add_xe2_kernel_library LIBRARY_NAME)
# Parse keyword arguments for add_xe2_kernel_library; ARG_* variables are set
# from the keywords below (PARSE_ARGV form handles empty/;-containing args).
cmake_parse_arguments(
PARSE_ARGV 1 ARG "INCLUDE_CMAKE_SOURCE_DIR" # Boolean options
"DESTINATION" # Single value keywords
# NOTE(review): the next two lines are the old ("") and new ("SOURCES")
# multi-value keyword argument from the interleaved diff — only one of them
# belongs in the final call; keep "SOURCES".
"" # Multi-value keywords
"SOURCES" # Multi-value keywords
)

# Set default destination if not provided
Expand All @@ -560,8 +560,12 @@ function(add_xe2_kernel_library LIBRARY_NAME)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# NOTE(review): the next two lines are the REMOVED side of the diff (old
# comment + unconditional glob); they are superseded by the if(ARG_SOURCES)
# branch below — keep only one side in the final file.
# Find all source files
file(GLOB_RECURSE KERNEL_SOURCES "*.cpp" ${ATTN_KERNEL_SRCS_GEN})
# Find source files (use explicit list when provided).
# Prefer an explicit SOURCES list; fall back to globbing only when the caller
# did not pass one (globbing misses new files until re-configure).
if(ARG_SOURCES)
set(KERNEL_SOURCES ${ARG_SOURCES})
else()
file(GLOB_RECURSE KERNEL_SOURCES "*.cpp" ${ATTN_KERNEL_SRCS_GEN})
endif()

# Create shared library (the comment previously said "static", but SHARED is
# what add_library is given).
add_library(${LIBRARY_NAME} SHARED ${KERNEL_SOURCES})
Expand Down
26 changes: 14 additions & 12 deletions csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
std::optional<const at::Tensor>& q_scale,
std::optional<const at::Tensor>& k_scale,
std::optional<const at::Tensor>& v_scale,
float softmax_scale,
Expand All @@ -63,21 +64,17 @@ std::vector<at::Tensor> mha_varlen_fwd(
std::optional<int> num_splits) {
auto q_type = q.scalar_type();
auto k_type = k.scalar_type();
TORCH_CHECK(
q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
"VLLM Kernel XPU only supports fp16 and bf16 type");
auto v_type = v.scalar_type();

TORCH_CHECK(
v.scalar_type() == k_type, "key and value must have the same dtype");
bool is_fp8kv = false;
if (k_type == at::ScalarType::Float8_e5m2 ||
k_type == at::ScalarType::Float8_e4m3fn) {
is_fp8kv = true;
} else {
bool is_fp8_q = q_type == at::ScalarType::Float8_e5m2 ||
q_type == at::ScalarType::Float8_e4m3fn;
bool is_fp8kv = k_type == at::ScalarType::Float8_e5m2 ||
k_type == at::ScalarType::Float8_e4m3fn;
if (is_fp8kv == is_fp8_q) {
TORCH_CHECK(
k.scalar_type() == q_type, "query and key must have the same dtype");
TORCH_CHECK(
v.scalar_type() == q_type, "query and value must have the same dtype");
}

CHECK_DEVICE(q);
Expand Down Expand Up @@ -128,6 +125,10 @@ std::vector<at::Tensor> mha_varlen_fwd(
} else {
out = torch::empty_like(q);
}
TORCH_CHECK(
out.scalar_type() == at::ScalarType::Half ||
out.scalar_type() == at::ScalarType::BFloat16,
"VLLM Kernel XPU only supports fp16 and bf16 type");

bool is_varlen = true;
bool is_local = (window_size_left != -1) | (window_size_right != -1);
Expand All @@ -147,6 +148,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
seqlens_k,
max_seqlen_q,
max_seqlen_k,
q_scale,
k_scale,
v_scale,
softmax_scale,
Expand Down Expand Up @@ -241,8 +243,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? "
"block_table, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, Tensor? k_scale, "
"Tensor? v_scale, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, Tensor? q_scale, "
"Tensor? k_scale, Tensor? v_scale, "
"float softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
"bool is_causal, int window_size_left, int window_size_right, float "
"softcap, bool return_softmax, "
Expand Down
2 changes: 2 additions & 0 deletions csrc/xpu/attn/attn_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ void cutlass_chunk_prefill_interface(
const at::Tensor& cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
std::optional<const at::Tensor>& q_scale,
std::optional<const at::Tensor>& k_scale,
std::optional<const at::Tensor>& v_scale,
double sm_scale,
Expand All @@ -42,6 +43,7 @@ void cutlass_chunk_prefill_interface(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q_scale,
k_scale,
v_scale,
sm_scale,
Expand Down
1 change: 1 addition & 0 deletions csrc/xpu/attn/attn_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ void cutlass_chunk_prefill_interface(
const at::Tensor& cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
std::optional<const at::Tensor>& q_scale,
std::optional<const at::Tensor>& k_scale,
std::optional<const at::Tensor>& v_scale,
double sm_scale,
Expand Down
29 changes: 28 additions & 1 deletion csrc/xpu/attn/xe_2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,32 @@ fmha_forward_configure(chunk_prefill_kernel_template)

# Generate paged-decode kernel instantiations; GEN_KERNEL_SRCS is presumably
# populated by paged_decode_configure — TODO confirm in
# paged_decode_configure.cmake.
include("paged_decode_configure.cmake")
paged_decode_configure(paged_decode_kernel_template)
set(ATTN_DECODE_SRCS_GEN ${GEN_KERNEL_SRCS})

# NOTE(review): the next line appears to be the REMOVED side of the diff (the
# single monolithic attention library); it is replaced by the split
# prefill/decode/fp8q libraries below — confirm and delete it.
add_xe2_kernel_library(attn_kernels_xe_2 INCLUDE_CMAKE_SOURCE_DIR)
# Split generated chunk prefill instantiations to reduce peak link pressure.
# ATTN_KERNEL_SRCS_GEN_FP8Q / _NON_FP8Q are presumably produced by
# fmha_forward_configure earlier in this file — TODO confirm.
set(ATTN_PREFILL_SRCS_FP8Q ${ATTN_KERNEL_SRCS_GEN_FP8Q})
set(ATTN_PREFILL_SRCS_NON_FP8Q ${ATTN_KERNEL_SRCS_GEN_NON_FP8Q})
set(ATTN_DECODE_SRCS ${ATTN_DECODE_SRCS_GEN})

# Keep runtime entry points in dedicated prefill/decode binaries.
list(APPEND ATTN_PREFILL_SRCS_NON_FP8Q ${CMAKE_CURRENT_SOURCE_DIR}/fmha_xe2.cpp)
list(APPEND ATTN_DECODE_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/paged_decode_xe2.cpp)

add_xe2_kernel_library(attn_prefill_kernels_xe_2 INCLUDE_CMAKE_SOURCE_DIR
SOURCES ${ATTN_PREFILL_SRCS_NON_FP8Q})

# Build the fp8-query prefill library only when its generated sources exist,
# and export ATTN_HAS_FP8Q_LIB to the parent scope so the top-level
# CMakeLists.txt can conditionally append it to ATTN_KERNEL_LIB_NAME.
if(ATTN_PREFILL_SRCS_FP8Q)
add_xe2_kernel_library(
attn_prefill_kernels_xe_2_fp8q INCLUDE_CMAKE_SOURCE_DIR SOURCES
${ATTN_PREFILL_SRCS_FP8Q})
set(ATTN_HAS_FP8Q_LIB
ON
PARENT_SCOPE)
else()
set(ATTN_HAS_FP8Q_LIB
OFF
PARENT_SCOPE)
endif()

add_xe2_kernel_library(attn_decode_kernels_xe_2 INCLUDE_CMAKE_SOURCE_DIR
SOURCES ${ATTN_DECODE_SRCS})
138 changes: 31 additions & 107 deletions csrc/xpu/attn/xe_2/chunk_prefill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct chunk_prefill_args_t {
int max_keys;
int total_seqlen_q;
int total_seqlen_k;
void* q_scale;
void* k_scale;
void* v_scale;
float sm_scale;
Expand Down Expand Up @@ -145,8 +146,9 @@ struct KernelLauncher {
stride_V,
reinterpret_cast<ElementO*>(args.out),
stride_O,
reinterpret_cast<ElementQ*>(args.sm_sink)},
reinterpret_cast<ElementO*>(args.sm_sink)},
{args.sm_scale,
args.q_scale,
args.k_scale,
args.v_scale,
static_cast<int*>(args.block_table),
Expand Down Expand Up @@ -232,9 +234,10 @@ template <
struct FMHAConfig {
static constexpr int SGTileQ =
get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))();
// Note that always use output dtype for MMAOperation
using MMAOperation = cute::conditional_t<
is_void_v<MMAOperation_>,
XE_DPAS_TT<cute::gcd(SGTileQ, 8), float, ElementQ>,
XE_DPAS_TT<cute::gcd(SGTileQ, 8), float, ElementO>,
MMAOperation_>;
using SubgroupLayoutPV = cute::conditional_t<
is_void_v<SubgroupLayoutPV_>,
Expand Down Expand Up @@ -287,6 +290,7 @@ struct FMHAConfig {
TensorQ,
TensorK,
TensorV,
TensorO,
GmemTiledCopyQ,
GmemTiledCopyK,
GmemTiledCopyV>;
Expand Down Expand Up @@ -317,111 +321,31 @@ struct FMHAConfig {
}
};

template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink>
template <
typename chunk_policy,
typename ElementQ,
typename ElementKV,
typename ElementO,
bool Paged,
bool Causal,
bool Local,
bool Sink>
void policy_dispatch_impl(
sycl::queue& queue,
CutlassQKType& cuQKType,
const chunk_prefill_args_t& args) {
sycl::queue& queue, const chunk_prefill_args_t& args) {
const int PipelineStages = 2;
if (cuQKType.q_type == CutlassDType::half) {
if (cuQKType.k_type == CutlassDType::half) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
half_t,
half_t,
half_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
float_e4m3_t,
float_e4m3_t,
half_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
float_e5m2_t,
float_e5m2_t,
half_t>::kernel_dispatch(queue, args);
}
} else {
if (cuQKType.k_type == CutlassDType::bfloat16) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
bfloat16_t,
bfloat16_t,
bfloat16_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
float_e4m3_t,
float_e4m3_t,
bfloat16_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
float_e5m2_t,
float_e5m2_t,
bfloat16_t>::kernel_dispatch(queue, args);
}
}
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
ElementQ,
ElementKV,
ElementKV,
ElementO>::kernel_dispatch(queue, args);
}
Loading
Loading