Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,9 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
set(SYCL_LINK_FLAGS "")
list(APPEND SYCL_LINK_FLAGS "-fsycl")
set(SYCL_DEVICE_LINK_FLAGS ${SYCL_LINK_FLAGS})
set(SYCL_DEVICE_LINK_FLAGS ${SYCL_DEVICE_LINK_FLAGS}
-fsycl-max-parallel-link-jobs=16)
set(SYCL_DEVICE_LINK_FLAGS
${SYCL_DEVICE_LINK_FLAGS} -fsycl-max-parallel-link-jobs=16
-flink-huge-device-code)
set(SYCL_DEVICE_LINK_FLAGS
${SYCL_DEVICE_LINK_FLAGS}
"-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate"
Expand Down
5 changes: 4 additions & 1 deletion csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ std::vector<at::Tensor> mha_varlen_fwd(
int block_size = k.size(1);

int num_kv_splits = num_splits.value_or(get_num_splits(
queue, batch_size, num_heads_kv, effective_seqlen_k, block_size));
queue, batch_size, num_heads_kv, max_seqlen_k, block_size));
// Cap num_kv_splits to block_size which equals the decode kernel's
// max_num_kv_splits (SGPerWG * sg_size) for all page-size policies.
num_kv_splits = std::min(num_kv_splits, block_size);

at::Tensor tmp_out =
num_kv_splits == 1
Expand Down
30 changes: 30 additions & 0 deletions csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,36 @@ function(fmha_forward_configure FILENAME_SUFFIX)
endforeach()
endforeach()

# Page-size-specific paged-only policies (only Paged=true)
set(paged_policy_list
    "chunk_policy_head64_p16"
    "chunk_policy_head96_p16"
    "chunk_policy_head128_p16"
    "chunk_policy_head192_p16"
    "chunk_policy_head256_p16"
    "chunk_policy_head96_p32"
    "chunk_policy_head128_p32")

set(IMPL_KISPAGED "true")
foreach(IMPL_POLICY ${paged_policy_list})
  foreach(IMPL_KISCAUSAL ${L_BOOLS})
    foreach(IMPL_KISLOCAL ${L_BOOLS})
      foreach(IMPL_KISSINK ${L_BOOLS})
        # Suffix flag order is Paged, Causal, Sink, Local — it must match the
        # suffix produced by the generic policy loop above.
        string(CONCAT FILE_SUFFIX
               "${IMPL_POLICY}_"
               "${BOOL_FLAG_${IMPL_KISPAGED}}"
               "${BOOL_FLAG_${IMPL_KISCAUSAL}}"
               "${BOOL_FLAG_${IMPL_KISSINK}}"
               "${BOOL_FLAG_${IMPL_KISLOCAL}}")
        configure_file(${FILENAME_SUFFIX}.cpp.in
                       "${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
        list(
          APPEND GEN_KERNEL_SRCS
          "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
      endforeach()
    endforeach()
  endforeach()
endforeach()

list(REMOVE_DUPLICATES GEN_KERNEL_SRCS)
list(LENGTH GEN_KERNEL_SRCS GEN_KERNEL_SRCS_LENGTH)
message(
Expand Down
18 changes: 18 additions & 0 deletions csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@
// Apply the bool combination generator to all policies
CHUNK_POLICY_LIST(DECLARE_ALL_BOOL_COMBINATIONS)

// Page-size-specific paged-only policies (only Paged=true instantiations).
// X-macro list: applies X to every page-size-specialized chunk policy.
#define CHUNK_PAGED_POLICY_LIST(X) \
X(chunk_policy_head64_p16) \
X(chunk_policy_head96_p16) \
X(chunk_policy_head128_p16) \
X(chunk_policy_head192_p16) \
X(chunk_policy_head256_p16) \
X(chunk_policy_head96_p32) \
X(chunk_policy_head128_p32)

// These policies only make sense for paged KV, so expand straight to the
// causal-combination generator with Paged fixed to true.
#define DECLARE_ALL_BOOL_PAGED_ONLY(POLICY) DECLARE_FOR_CAUSAL(POLICY, true)

CHUNK_PAGED_POLICY_LIST(DECLARE_ALL_BOOL_PAGED_ONLY)

// Cleanup: keep the paged-only helper macros file-local.
#undef DECLARE_ALL_BOOL_PAGED_ONLY
#undef CHUNK_PAGED_POLICY_LIST

// Cleanup macros
#undef DECLARE_ALL_BOOL_COMBINATIONS
#undef DECLARE_FOR_CAUSAL
Expand Down
111 changes: 111 additions & 0 deletions csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,117 @@ void policy_dispatch_func(
}
}

// Dispatch by head size for non-paged or page_size >= 64 paths.
// Paged=false is passed as a bool arg so both paged and non-paged can use it.
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says Paged=false is passed as a bool arg, but this is a template parameter (template <bool IsPaged>), not a runtime bool argument. Update the comment to avoid confusing readers (e.g., clarify that paged-ness is a compile-time template parameter for selecting instantiations).

Suggested change
// Paged=false is passed as a bool arg so both paged and non-paged can use it.
// Paged-ness is selected via the compile-time template parameter IsPaged so
// both paged and non-paged paths can share this implementation.

Copilot uses AI. Check for mistakes.
template <bool IsPaged>
void dispatch_by_head_size_default(
sycl::queue& queue,
CutlassQKType& cuQKType,
const chunk_prefill_args_t& args,
bool is_causal,
bool is_local,
bool is_sink) {
if (args.head_size <= HEAD_SIZE_LIMIT_0) {
policy_dispatch_func<chunk_policy_head64, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_1) {
policy_dispatch_func<chunk_policy_head96, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_2) {
policy_dispatch_func<chunk_policy_head128, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_3) {
policy_dispatch_func<chunk_policy_head192, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_4) {
policy_dispatch_func<chunk_policy_head256, IsPaged>(
queue, cuQKType, args, is_causal, is_local, is_sink);
} else {
TORCH_CHECK(false, "Unsupported head size for fmha");
}
Comment on lines +39 to +56
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The head-size dispatch ladder is duplicated across dispatch_by_head_size_default, dispatch_by_head_size_p16, and dispatch_by_head_size_p32, increasing the chance of future inconsistencies (e.g., updating limits/policies in one path but not others). Consider factoring the common ladder into a single helper that takes the per-head policy types as template parameters (or a small traits mapping), so only the policy selection differs while the control flow stays in one place.

Copilot uses AI. Check for mistakes.
}

// Dispatch by head size for paged KV with page_size=32.
// head96/128 need a p32 policy (K-tile=32); others fall back to default.
inline void dispatch_by_head_size_p32(
    sycl::queue& queue,
    CutlassQKType& cuQKType,
    const chunk_prefill_args_t& args,
    bool is_causal,
    bool is_local,
    bool is_sink) {
  const int head_size = args.head_size;
  // Guard-clause ladder: the first matching head-size bucket wins.
  if (head_size <= HEAD_SIZE_LIMIT_0) {
    policy_dispatch_func<chunk_policy_head64, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_1) {
    policy_dispatch_func<chunk_policy_head96_p32, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_2) {
    policy_dispatch_func<chunk_policy_head128_p32, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_3) {
    policy_dispatch_func<chunk_policy_head192, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_4) {
    policy_dispatch_func<chunk_policy_head256, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  TORCH_CHECK(false, "Unsupported head size for fmha");
}

// Dispatch by head size for paged KV with page_size=16 (K-tile=16 for all).
inline void dispatch_by_head_size_p16(
    sycl::queue& queue,
    CutlassQKType& cuQKType,
    const chunk_prefill_args_t& args,
    bool is_causal,
    bool is_local,
    bool is_sink) {
  const int head_size = args.head_size;
  // Guard-clause ladder: the first matching head-size bucket wins, and every
  // bucket maps onto the corresponding *_p16 policy.
  if (head_size <= HEAD_SIZE_LIMIT_0) {
    policy_dispatch_func<chunk_policy_head64_p16, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_1) {
    policy_dispatch_func<chunk_policy_head96_p16, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_2) {
    policy_dispatch_func<chunk_policy_head128_p16, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_3) {
    policy_dispatch_func<chunk_policy_head192_p16, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  if (head_size <= HEAD_SIZE_LIMIT_4) {
    policy_dispatch_func<chunk_policy_head256_p16, true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  TORCH_CHECK(false, "Unsupported head size for fmha");
}

// Top-level dispatch: select head-size dispatch table by page size.
inline void dispatch_by_page_size(
    sycl::queue& queue,
    CutlassQKType& cuQKType,
    const chunk_prefill_args_t& args,
    bool is_paged,
    bool is_causal,
    bool is_local,
    bool is_sink) {
  // Non-paged KV takes the default table with Paged=false.
  if (!is_paged) {
    dispatch_by_head_size_default<false>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
    return;
  }
  // Paged KV: pick the table whose K-tile fits the page (block) size.
  const int block_size = args.block_size;
  if (block_size < 32) {
    dispatch_by_head_size_p16(
        queue, cuQKType, args, is_causal, is_local, is_sink);
  } else if (block_size < 64) {
    dispatch_by_head_size_p32(
        queue, cuQKType, args, is_causal, is_local, is_sink);
  } else {
    dispatch_by_head_size_default<true>(
        queue, cuQKType, args, is_causal, is_local, is_sink);
  }
}

void cutlass_chunk_prefill_impl(
sycl::queue& queue,
const at::Tensor& query, // [seq_q, heads, head_size]
Expand Down
3 changes: 3 additions & 0 deletions csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,9 @@ class DecodeFwdEpilogue {

if constexpr (ReduceK{} == _1{}) {
ReduceFragARow rA_max;
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
rA_max(0) = tA_max(0);
Comment on lines 588 to +591
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReduceFragARow rA_max; is default-initialized and only element 0 is set. If ReduceFragARow contains more than one element, the remaining values are uninitialized and can corrupt ReduceSplitK reductions. Initialize the entire fragment deterministically (e.g., fill from tA_max for all lanes/elements, or value-initialize then assign all required indices).

Suggested change
ReduceFragARow rA_max;
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
rA_max(0) = tA_max(0);
ReduceFragARow rA_max{};
// Initialize rA_max from tA_max so that max_logits is correct
// when num_kv_splits > 1 (used by ReduceSplitK).
for (int i = 0; i < cute::size(rA_max); ++i) {
rA_max(i) = tA_max(0);
}

Copilot uses AI. Check for mistakes.
return std::make_tuple(tArA, rA_max, tA_sum, true);
} else {
/* Identify A tile ID and k block for this subgroup. */
Expand Down
76 changes: 75 additions & 1 deletion csrc/xpu/attn/xe_2/fmha_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,88 @@ struct chunk_policy_head256 {
using SubgroupLayoutQK = Layout<Shape<_32, _1, _1>>;
};

// Page-size-aware chunk prefill policies.
// Used when paged KV has block_size smaller than the default K-tile.
// NOTE(review): tile-axis semantics of ShapeQK/ShapePV are inferred from
// naming conventions of the sibling chunk_policy_head* structs — confirm
// against the kernel that consumes these policies.

// block_size=16 variants (K-tile = 16)

// head_size <= 64, 128x64 output tile.
struct chunk_policy_head64_p16 {
  using ShapeQK = Shape<_128, _16, _32>;
  using ShapePV = Shape<_128, _32, _16>;
  using ShapeOut = Shape<_128, _64>;
  using SubgroupLayoutQK = Layout<Shape<_8, _1, _1>>;
};

// head_size <= 96, 128x96 output tile.
struct chunk_policy_head96_p16 {
  using ShapeQK = Shape<_128, _16, _32>;
  using ShapePV = Shape<_128, _32, _16>;
  using ShapeOut = Shape<_128, _96>;
  using SubgroupLayoutQK = Layout<Shape<_8, _1, _1>>;
};

// head_size <= 128, 128x128 output tile; doubles the subgroup count.
struct chunk_policy_head128_p16 {
  using ShapeQK = Shape<_128, _16, _32>;
  using ShapePV = Shape<_128, _32, _16>;
  using ShapeOut = Shape<_128, _128>;
  using SubgroupLayoutQK = Layout<Shape<_16, _1, _1>>;
};

// head_size <= 192, 256x192 output tile.
struct chunk_policy_head192_p16 {
  using ShapeQK = Shape<_256, _16, _32>;
  using ShapePV = Shape<_256, _32, _16>;
  using ShapeOut = Shape<_256, _192>;
  using SubgroupLayoutQK = Layout<Shape<_32, _1, _1>>;
};

// head_size <= 256, 256x256 output tile.
struct chunk_policy_head256_p16 {
  using ShapeQK = Shape<_256, _16, _32>;
  using ShapePV = Shape<_256, _32, _16>;
  using ShapeOut = Shape<_256, _256>;
  using SubgroupLayoutQK = Layout<Shape<_32, _1, _1>>;
};

// block_size=32 variants (K-tile = 32)
// Only needed for head96/head128 whose default K-tile is 64

// head_size <= 96 with a 32-row K-tile.
struct chunk_policy_head96_p32 {
  using ShapeQK = Shape<_128, _32, _32>;
  using ShapePV = Shape<_128, _32, _32>;
  using ShapeOut = Shape<_128, _96>;
  using SubgroupLayoutQK = Layout<Shape<_8, _1, _1>>;
};

// head_size <= 128 with a 32-row K-tile.
struct chunk_policy_head128_p32 {
  using ShapeQK = Shape<_128, _32, _32>;
  using ShapePV = Shape<_128, _32, _32>;
  using ShapeOut = Shape<_128, _128>;
  using SubgroupLayoutQK = Layout<Shape<_16, _1, _1>>;
};

// define decode policy
template <typename q_packed, typename head_dim, typename kv_tile>
struct decode_policy_qpacked_head {
static_assert(
cute::is_same_v<kv_tile, _64> || cute::is_same_v<kv_tile, _128>,
cute::is_same_v<kv_tile, _16> || cute::is_same_v<kv_tile, _32> ||
cute::is_same_v<kv_tile, _64> || cute::is_same_v<kv_tile, _128>,
"Unsupported kv_tile(page_size) for decode_policy_qpacked_head");
};

// kv_tile == _16: decode policy for paged KV with page_size=16.
template <typename q_packed, typename head_dim>
struct decode_policy_qpacked_head<q_packed, head_dim, _16> {
  using ShapeQK = Shape<q_packed, _16, _64>;
  using ShapePV = Shape<q_packed, _32, _16>;
  using ShapeOut = Shape<q_packed, head_dim>;
  using SubgroupLayoutQK = Layout<Shape<_1, _1, _1>>;
};

// kv_tile == _32: decode policy for paged KV with page_size=32.
// NOTE(review): SubgroupLayoutQK's middle extent is _2 here (vs _1 for p16)
// — presumably splitting the K dimension across two subgroups; confirm
// against the decode kernel's layout usage.
template <typename q_packed, typename head_dim>
struct decode_policy_qpacked_head<q_packed, head_dim, _32> {
  using ShapeQK = Shape<q_packed, _32, _64>;
  using ShapePV = Shape<q_packed, _32, _32>;
  using ShapeOut = Shape<q_packed, head_dim>;
  using SubgroupLayoutQK = Layout<Shape<_1, _2, _1>>;
};

// kv_tile == _64
template <typename q_packed, typename head_dim>
struct decode_policy_qpacked_head<q_packed, head_dim, _64> {
Expand Down
20 changes: 2 additions & 18 deletions csrc/xpu/attn/xe_2/fmha_xe2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,6 @@ void cutlass_chunk_prefill_impl(
"FMHA forward only supports head dimension at most " +
std::to_string(max_head_size));

if (args.head_size <= HEAD_SIZE_LIMIT_0) {
policy_dispatch_func<chunk_policy_head64>(
queue, cuQKType, args, is_paged, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_1) {
policy_dispatch_func<chunk_policy_head96>(
queue, cuQKType, args, is_paged, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_2) {
policy_dispatch_func<chunk_policy_head128>(
queue, cuQKType, args, is_paged, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_3) {
policy_dispatch_func<chunk_policy_head192>(
queue, cuQKType, args, is_paged, is_causal, is_local, is_sink);
} else if (args.head_size <= HEAD_SIZE_LIMIT_4) {
policy_dispatch_func<chunk_policy_head256>(
queue, cuQKType, args, is_paged, is_causal, is_local, is_sink);
} else {
TORCH_CHECK(false, "Unsupported head size for fmha");
}
dispatch_by_page_size(
queue, cuQKType, args, is_paged, is_causal, is_local, is_sink);
}
22 changes: 22 additions & 0 deletions csrc/xpu/attn/xe_2/paged_decode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@

using namespace cute;

// Concrete decode-policy aliases: decode_policy_q<packed>_h<head_dim>_p<page>.
// page_size=16 variants, packed-Q widths 8 and 16, head dims 64..256.
using decode_policy_q8_h64_p16 = decode_policy_qpacked_head<_8, _64, _16>;
using decode_policy_q8_h96_p16 = decode_policy_qpacked_head<_8, _96, _16>;
using decode_policy_q8_h128_p16 = decode_policy_qpacked_head<_8, _128, _16>;
using decode_policy_q8_h192_p16 = decode_policy_qpacked_head<_8, _192, _16>;
using decode_policy_q8_h256_p16 = decode_policy_qpacked_head<_8, _256, _16>;
using decode_policy_q16_h64_p16 = decode_policy_qpacked_head<_16, _64, _16>;
using decode_policy_q16_h96_p16 = decode_policy_qpacked_head<_16, _96, _16>;
using decode_policy_q16_h128_p16 = decode_policy_qpacked_head<_16, _128, _16>;
using decode_policy_q16_h192_p16 = decode_policy_qpacked_head<_16, _192, _16>;
using decode_policy_q16_h256_p16 = decode_policy_qpacked_head<_16, _256, _16>;

// page_size=32 variants, same packed-Q widths and head dims.
using decode_policy_q8_h64_p32 = decode_policy_qpacked_head<_8, _64, _32>;
using decode_policy_q8_h96_p32 = decode_policy_qpacked_head<_8, _96, _32>;
using decode_policy_q8_h128_p32 = decode_policy_qpacked_head<_8, _128, _32>;
using decode_policy_q8_h192_p32 = decode_policy_qpacked_head<_8, _192, _32>;
using decode_policy_q8_h256_p32 = decode_policy_qpacked_head<_8, _256, _32>;
using decode_policy_q16_h64_p32 = decode_policy_qpacked_head<_16, _64, _32>;
using decode_policy_q16_h96_p32 = decode_policy_qpacked_head<_16, _96, _32>;
using decode_policy_q16_h128_p32 = decode_policy_qpacked_head<_16, _128, _32>;
using decode_policy_q16_h192_p32 = decode_policy_qpacked_head<_16, _192, _32>;
using decode_policy_q16_h256_p32 = decode_policy_qpacked_head<_16, _256, _32>;

using decode_policy_q8_h64_p64 = decode_policy_qpacked_head<_8, _64, _64>;
using decode_policy_q8_h96_p64 = decode_policy_qpacked_head<_8, _96, _64>;
using decode_policy_q8_h128_p64 = decode_policy_qpacked_head<_8, _128, _64>;
Expand Down
Loading
Loading