Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ template <
class DispatchPolicy_,
bool PagedKV_,
bool CausalMask_,
bool LocalMask_,
class TiledMMAQK_, // Tiling for Q*K GEMM
class TiledMMAPV_, // Tiling for P*V GEMM
int VTiles_, // # of tiles in V dimension
Expand All @@ -568,6 +569,7 @@ template <
int Stages,
bool PagedKV_,
bool CausalMask_,
bool LocalMask_,
class TiledMMAQK_,
class TiledMMAPV_,
int VTiles_,
Expand All @@ -581,6 +583,7 @@ struct DecodeFwdMainloop<
XeDefault<Stages>,
PagedKV_,
CausalMask_,
LocalMask_,
TiledMMAQK_,
TiledMMAPV_,
VTiles_,
Expand Down Expand Up @@ -657,6 +660,7 @@ struct DecodeFwdMainloop<

static constexpr bool PagedKV = PagedKV_;
static constexpr bool CausalMask = CausalMask_;
static constexpr bool LocalMask = LocalMask_;

// User-facing arguments
struct Arguments {
Expand All @@ -666,6 +670,8 @@ struct DecodeFwdMainloop<
int page_size;
int max_pages_per_seq;
int total_seqlen_kv;
// Local Mask
int local_left, local_right;
};

// Kernel-facing parameters
Expand All @@ -691,7 +697,9 @@ struct DecodeFwdMainloop<
args.ptr_page_table,
args.page_size,
args.max_pages_per_seq,
args.total_seqlen_kv};
args.total_seqlen_kv,
args.local_left,
args.local_right};
}

CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) {
Expand Down Expand Up @@ -875,6 +883,23 @@ struct DecodeFwdMainloop<
// }
// }

if constexpr (LocalMask) {
Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len));
Tensor gP = local_tile(
cPgP, take<0, 2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K));
auto cS_thread = thr_mma_qk.partition_C(gP);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < tSrS.size(); ++i) {
int row_idx = get<0>(cS_thread(i)) - discard_seq_coord;
int col_idx = get<1>(cS_thread(i)) - full_tile_offset;
bool left_mask = col_idx < row_idx - params.local_left;
bool right_mask = col_idx > row_idx + params.local_right;
if (left_mask || right_mask) {
tSrS(i) = ElementS(-INFINITY);
}
}
}

/* k masking for remainder tiles */
if (check_remainder_k && K == blk_k1 - 1) {
FragSCol k_rem_mask;
Expand Down
5 changes: 4 additions & 1 deletion csrc/xpu/attn/xe_2/paged_decode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ struct DecodeKernelLauncher {
static_cast<int*>(args.block_table),
args.block_size,
args.max_blocks_per_seq,
args.total_seqlen_k},
args.total_seqlen_k,
args.window_size_left,
args.window_size_right},
{},
hw_info,
args.num_kv_splits};
Expand Down Expand Up @@ -375,6 +377,7 @@ struct PagedDecodeConfig {
MainloopDispatchPolicy,
Paged,
Causal,
Local,
TiledMMAQK,
TiledMMAPV,
VTiles,
Expand Down
10 changes: 7 additions & 3 deletions tests/flash_attn/test_flash_attn_varlen_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def test_varlen_with_paged_kv(
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("window_size", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
Expand All @@ -315,6 +316,7 @@ def test_decode_with_paged_kv(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
window_size: tuple[int, int],
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
Expand Down Expand Up @@ -392,7 +394,7 @@ def test_decode_with_paged_kv(
softmax_scale=scale,
causal=False,
block_table=block_tables,
window_size=(-1, -1),
window_size=window_size,
s_aux=sink)

ref_output = ref_paged_attn(query=query,
Expand All @@ -405,9 +407,11 @@ def test_decode_with_paged_kv(
casual=False,
is_paged=True,
sink=sink,
window_size_left=-1,
window_size_right=-1)
window_size_left=window_size[0],
window_size_right=window_size[1])
atol, rtol = 1e-2, 1e-2
if window_size[0] != -1 or window_size[1] != -1:
atol, rtol = 2e-2, 2e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
Expand Down