Skip to content

Commit e308efd

Browse files
authored
vulkan: in flash attention, bounds check against nem1 (don't rely on GGML_KQ_MASK_PAD) (#16316)
1 parent 136bda7 commit e308efd

File tree

4 files changed: 27 additions and 12 deletions.

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2614,8 +2614,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
26142614
const uint32_t D_lsb = D ^ (D & (D-1));
26152615
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
26162616

2617-
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
2618-
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
26192617
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
26202618
};
26212619

@@ -7457,8 +7455,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
74577455
if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
74587456
aligned = false;
74597457
}
7460-
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
7461-
GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
74627458

74637459
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
74647460

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,12 +153,13 @@ void main() {
153153
}
154154

155155
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
156+
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
156157

157158
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
158159
uint32_t c = (idx + tid) % Bc;
159160
uint32_t r = (idx + tid) / Bc;
160161
if (idx + tid < Bc * Br) {
161-
if (!KV_bounds_check || j * Bc + c < KV) {
162+
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
162163
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
163164
} else {
164165
masksh[c][r] = float(0);

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -201,11 +201,13 @@ void main() {
201201
}
202202

203203
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
204+
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
205+
204206
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
205207
uint32_t c = (idx + tid) % Bc;
206208
uint32_t r = (idx + tid) / Bc;
207209
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
208-
if (!KV_bounds_check || j * Bc + c < KV) {
210+
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
209211
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
210212
}
211213
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -154,15 +154,31 @@ void main() {
154154
}
155155

156156
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
157-
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
158-
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
159-
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
157+
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
160158

161-
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
159+
if (nem1_bounds_check) {
160+
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
161+
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
162+
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
162163

163-
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
164+
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
164165

165-
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
166+
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
167+
168+
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
169+
} else {
170+
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
171+
// Don't clamp against nem1 when GQA is enabled
172+
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
173+
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
174+
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
175+
176+
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
177+
178+
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
179+
180+
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
181+
}
166182
}
167183

168184
// Clear padding elements to -inf, so they don't contribute to rowmax

0 commit comments

Comments
 (0)