
Commit 491976e

[GPU] Limit sdpa_micro attention mask load to avoid compilation error
The earlier commit 3b84486 caused a kernel compilation error for attention masks of shape [batch, num_heads, 1, 1]. For this shape, restore the earlier tile_load_t call.
1 parent 8d83df6 commit 491976e
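
In other words, the fix narrows the special-case mask load to shapes where the mask genuinely has a single query row but a real key-sequence extent. A hypothetical host-side mirror of that predicate (names are illustrative, not from the OpenVINO sources):

```cpp
// Hypothetical mirror of the kernel's new dispatch predicate; MSK_D2/MSK_D3
// are the mask's last two dims ([batch, num_heads, MSK_D2, MSK_D3]).
constexpr bool use_broadcast_mask_load(int msk_d2, int msk_d3) {
    return msk_d2 == 1 && msk_d3 > 1;
}
static_assert(!use_broadcast_mask_load(1, 1),   "[b, h, 1, 1] takes the restored tile_load_t path");
static_assert( use_broadcast_mask_load(1, 128), "[b, h, 1, seq_len] keeps the broadcast load");
```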

2 files changed: +62, −33 lines

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl

Lines changed: 12 additions & 8 deletions
```diff
@@ -404,14 +404,18 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
 #if WITH_ATTN_MASK
     /* Load mask. No remainder handling needed assuming k block size is a power of 2. */
     mask_tile_type mask_tile;
-    const uint mask_m = MSK_D2;
-    const uint mask_n = MSK_D3;
-    // Check if attention mask has a single Query dimension (e.g., [batch, num_heads, 1, sequence_length])
-    // In the case of single query dimension, set ld and offset_r to zero
-    // to avoid exceeding bounds for single dimension.
-    const uint mask_ld = (mask_m == 1)? 0 : mask_n;
-    const uint mask_offset_r = (mask_m == 1)? 0 : sg_j0_kq + wg_j0;
-    tile_load_t(&mask_tile, msk, mask_m, mask_n, mask_ld, mask_offset_r, k0 + sg_i0_kq);
+    if (MSK_D2 == 1 && MSK_D3 > 1) {
+        // Check if attention mask has a single Query dimension (e.g., [batch, num_heads, 1, sequence_length])
+        // In the case of single query dimension, set ld and offset_r to zero
+        // to avoid exceeding bounds for single dimension.
+        const uint mask_m = MSK_D2;
+        const uint mask_n = MSK_D3;
+        const uint mask_ld = (mask_m == 1)? 0 : mask_n;
+        const uint mask_offset_r = (mask_m == 1)? 0 : sg_j0_kq + wg_j0;
+        tile_load_t(&mask_tile, msk, mask_m, mask_n, mask_ld, mask_offset_r, k0 + sg_i0_kq);
+    } else {
+        tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
+    }
 #endif
 
 #if REMAINDER_K
```
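
Why setting ld and offset_r to zero realizes the broadcast: in a row-major tiled load the leading dimension is the stride between source rows, so ld = 0 makes every tile row alias mask row 0, and the load can never step past a single-row mask. A minimal standalone sketch (plain C++, not the actual tile_load_t implementation, whose internals are outside this diff):

```cpp
#include <cstdio>

// Toy row-major tiled load: ld is the stride between consecutive source rows.
// With ld == 0 and offset_r == 0, every tile row reads the same source row --
// the broadcast needed for a [batch, num_heads, 1, seq_len] attention mask.
static void toy_tile_load(float* tile, const float* src, int rows, int cols,
                          int ld, int offset_r, int offset_c) {
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            tile[r * cols + c] = src[(offset_r + r) * ld + offset_c + c];
}

int main() {
    const float mask_row[4] = {0.f, -1.f, -2.f, -3.f};  // one mask row, seq_len = 4
    float tile[2 * 4];
    toy_tile_load(tile, mask_row, 2, 4, /*ld=*/0, /*offset_r=*/0, /*offset_c=*/0);
    for (int r = 0; r < 2; ++r) {  // both rows print: 0 -1 -2 -3
        for (int c = 0; c < 4; ++c) std::printf("%g ", tile[r * 4 + c]);
        std::printf("\n");
    }
}
```

With a nonzero ld (the mask_n branch) the same loop walks distinct rows, which is what the restored plain call relies on for multi-row masks.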

src/plugins/intel_gpu/tests/unit/test_cases/sdpa_gpu_test.cpp

Lines changed: 50 additions & 25 deletions
```diff
@@ -35,22 +35,25 @@ struct sdpa_test_params {
     int sequence_length_q;
     int sequence_length_kv;
     int batch;
+    bool dynamic;
     bool use_scalar_scale_val;
     float scale_val;
     bool use_scalar_attn_mask;
     float attn_mask_val;
 
     // Constructor for basic tests (backward compatibility)
-    sdpa_test_params(int h_size, int n_heads, int seq_q, int seq_kv, int b)
+    sdpa_test_params(int h_size, int n_heads, int seq_q, int seq_kv, int b,
+                     bool dynamic_shape)
         : head_size(h_size), num_heads(n_heads), sequence_length_q(seq_q),
-          sequence_length_kv(seq_kv), batch(b), use_scalar_scale_val(false),
-          scale_val(1.0f), use_scalar_attn_mask(false), attn_mask_val(0.0f) {}
+          sequence_length_kv(seq_kv), batch(b), dynamic(dynamic_shape),
+          use_scalar_scale_val(false), scale_val(1.0f), use_scalar_attn_mask(false),
+          attn_mask_val(0.0f) {}
 
     // Constructor for advanced caching tests
     sdpa_test_params(int h_size, int n_heads, int seq_q, int seq_kv, int b,
                      bool use_scale, float scale, bool use_mask, float mask)
-        : head_size(h_size), num_heads(n_heads), sequence_length_q(seq_q),
-          sequence_length_kv(seq_kv), batch(b), use_scalar_scale_val(use_scale),
+        : head_size(h_size), num_heads(n_heads), sequence_length_q(seq_q), sequence_length_kv(seq_kv),
+          batch(b), dynamic(true), use_scalar_scale_val(use_scale),
           scale_val(scale), use_scalar_attn_mask(use_mask), attn_mask_val(mask) {}
 };
 
@@ -69,10 +72,10 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
     }
 
     std::tuple<cldnn::memory::ptr, cldnn::network::ptr> run_network(bool is_caching_test, bool use_micro_sdpa,
-                                                                    cldnn::layout input0_dyn_layout,
-                                                                    cldnn::layout input1_dyn_layout,
-                                                                    cldnn::layout input2_dyn_layout,
-                                                                    cldnn::layout input3_dyn_layout,
+                                                                    cldnn::layout input0_layout,
+                                                                    cldnn::layout input1_layout,
+                                                                    cldnn::layout input2_layout,
+                                                                    cldnn::layout input3_layout,
                                                                     cldnn::memory::ptr input0,
                                                                     cldnn::memory::ptr input1,
                                                                     cldnn::memory::ptr input2,
@@ -83,10 +86,10 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
                                                                     float attn_mask_val = 0.0f) {
         auto& engine = get_test_engine();
         topology topo;
-        topo.add(input_layout("input0", input0_dyn_layout));
-        topo.add(input_layout("input1", input1_dyn_layout));
-        topo.add(input_layout("input2", input2_dyn_layout));
-        topo.add(input_layout("input3", input3_dyn_layout));
+        topo.add(input_layout("input0", input0_layout));
+        topo.add(input_layout("input1", input1_layout));
+        topo.add(input_layout("input2", input2_layout));
+        topo.add(input_layout("input3", input3_layout));
 
         auto sdpa_prim = scaled_dot_product_attention("sdpa", {input_info("input0"), input_info("input1"), input_info("input2"), input_info("input3")},
                                                       false, -1, {0,2,1,3}, {0,2,1,3}, {0,2,1,3}, {0,1,2,3}, {}, false);
@@ -137,15 +140,30 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
         const auto attn_mask_val = p.attn_mask_val;
 
         auto& engine = get_test_engine();
-        cldnn::layout input0_dyn_layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
-        cldnn::layout input1_dyn_layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
-        cldnn::layout input2_dyn_layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
-        cldnn::layout input3_dyn_layout({-1, num_heads, -1, -1}, data_types::f16, format::bfyx);
-
-        cldnn::layout input0_static_layout({batch, seq_length_q, num_heads, head_size}, data_types::f16, format::bfyx);
-        cldnn::layout input1_static_layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
-        cldnn::layout input2_static_layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
-        cldnn::layout input3_static_layout({batch, num_heads, 1, seq_length_kv}, data_types::f16, format::bfyx);
+        cldnn::layout input0_layout, input1_layout, input2_layout, input3_layout;
+        cldnn::layout input0_static_layout, input1_static_layout, input2_static_layout, input3_static_layout;
+
+        if (p.dynamic) {
+            input0_layout = cldnn::layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
+            input1_layout = cldnn::layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
+            input2_layout = cldnn::layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
+            input3_layout = cldnn::layout({-1, num_heads, -1, -1}, data_types::f16, format::bfyx);
+
+            input0_static_layout = cldnn::layout({batch, seq_length_q, num_heads, head_size}, data_types::f16, format::bfyx);
+            input1_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+            input2_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+            input3_static_layout = cldnn::layout({batch, num_heads, 1, seq_length_kv}, data_types::f16, format::bfyx);
+        } else {
+            input0_static_layout = cldnn::layout({batch, seq_length_q, num_heads, head_size}, data_types::f16, format::bfyx);
+            input1_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+            input2_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+            input3_static_layout = cldnn::layout({batch, num_heads, 1, seq_length_kv}, data_types::f16, format::bfyx);
+
+            input0_layout = input0_static_layout;
+            input1_layout = input1_static_layout;
+            input2_layout = input2_static_layout;
+            input3_layout = input3_static_layout;
+        }
 
         auto input0 = engine.allocate_memory(input0_static_layout);
         auto input1 = engine.allocate_memory(input1_static_layout);
@@ -158,11 +176,11 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
         load_input(input3, 3);
 
         auto [mem_ref_ptr, net_ref_ptr] = run_network(is_caching_test, false,
-            input0_dyn_layout, input1_dyn_layout, input2_dyn_layout, input3_dyn_layout,
+            input0_layout, input1_layout, input2_layout, input3_layout,
            input0, input1, input2, input3,
            use_scalar_scale_val, scale_val, use_scalar_attn_mask, attn_mask_val);
         auto [mem_opt_ptr, net_opt_ptr] = run_network(is_caching_test, true,
-            input0_dyn_layout, input1_dyn_layout, input2_dyn_layout, input3_dyn_layout,
+            input0_layout, input1_layout, input2_layout, input3_layout,
            input0, input1, input2, input3,
            use_scalar_scale_val, scale_val, use_scalar_attn_mask, attn_mask_val);
 
@@ -225,6 +243,10 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
             result += "_mask_" + std::to_string(static_cast<int>(info.param.attn_mask_val * 1000));
         }
 
+        if (!info.param.dynamic) {
+            result += "_static";
+        }
+
         return result;
     }
 };
@@ -233,7 +255,10 @@ INSTANTIATE_TEST_SUITE_P(
     smoke_sdpa_gpu_test,
     sdpa_gpu_test,
     ::testing::Values(
-        sdpa_test_params{64, 32, 990, 128, 2},
+        sdpa_test_params{64, 32, 990, 128, 2, true},   // dynamic
+        sdpa_test_params{64, 32, 990, 128, 2, false},  // static
+        sdpa_test_params{64, 32, 990, 1, 2, true},     // dynamic
+        sdpa_test_params{64, 32, 990, 1, 2, false},    // static
         sdpa_test_params{64, 32, 128, 128, 2, true, 0.125f, false, 0.0f},  // scale_val only
         sdpa_test_params{64, 32, 128, 128, 2, false, 1.0f, true, 0.5f}     // attn_mask only
     ),
```
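
As a quick cross-check of the new coverage (a standalone sketch, not the cldnn test fixture): the added seq_kv = 1 rows give the mask input, laid out as {batch, num_heads, 1, seq_length_kv} above, exactly the [batch, num_heads, 1, 1] shape cited in the commit message, in both dynamic and static flavors.

```cpp
#include <cstdio>

// Standalone sketch: reproduce the mask (input3) static layout the test
// computes, {batch, num_heads, 1, seq_length_kv}, for the newly added rows.
// seq_kv == 1 yields [b, h, 1, 1] -- the shape that previously failed to compile.
struct case_t { int head_size, num_heads, seq_q, seq_kv, batch; bool dynamic; };

int main() {
    const case_t added[] = {
        {64, 32, 990, 128, 2, true},  {64, 32, 990, 128, 2, false},
        {64, 32, 990, 1,   2, true},  {64, 32, 990, 1,   2, false},
    };
    for (const auto& c : added)
        std::printf("mask layout: {%d, %d, 1, %d} (%s)\n", c.batch, c.num_heads,
                    c.seq_kv, c.dynamic ? "dynamic" : "static");
}
```

Assuming the standard googletest runner for this unit-test binary, the static variants can be selected with a filter such as --gtest_filter=*sdpa*static* (hypothetical invocation; the "_static" suffix comes from the name-generator branch added above).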
