@@ -35,22 +35,25 @@ struct sdpa_test_params {
    int sequence_length_q;
    int sequence_length_kv;
    int batch;
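+   // when true, the network is compiled with dynamic (shape-agnostic) input layouts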
+   bool dynamic;
    bool use_scalar_scale_val;
    float scale_val;
    bool use_scalar_attn_mask;
    float attn_mask_val;

    // Constructor for basic tests (backward compatibility)
-   sdpa_test_params(int h_size, int n_heads, int seq_q, int seq_kv, int b)
+   sdpa_test_params(int h_size, int n_heads, int seq_q, int seq_kv, int b,
+                    bool dynamic_shape)
        : head_size(h_size), num_heads(n_heads), sequence_length_q(seq_q),
-         sequence_length_kv(seq_kv), batch(b), use_scalar_scale_val(false),
-         scale_val(1.0f), use_scalar_attn_mask(false), attn_mask_val(0.0f) {}
+         sequence_length_kv(seq_kv), batch(b), dynamic(dynamic_shape),
+         use_scalar_scale_val(false), scale_val(1.0f), use_scalar_attn_mask(false),
+         attn_mask_val(0.0f) {}

    // Constructor for advanced caching tests
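+   // (these parameterized caching tests always run with dynamic shapes: dynamic is fixed to true)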
    sdpa_test_params(int h_size, int n_heads, int seq_q, int seq_kv, int b,
                     bool use_scale, float scale, bool use_mask, float mask)
-       : head_size(h_size), num_heads(n_heads), sequence_length_q(seq_q),
-         sequence_length_kv(seq_kv), batch(b), use_scalar_scale_val(use_scale),
+       : head_size(h_size), num_heads(n_heads), sequence_length_q(seq_q), sequence_length_kv(seq_kv),
+         batch(b), dynamic(true), use_scalar_scale_val(use_scale),
          scale_val(scale), use_scalar_attn_mask(use_mask), attn_mask_val(mask) {}
};

@@ -69,10 +72,10 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
    }

    std::tuple<cldnn::memory::ptr, cldnn::network::ptr> run_network(bool is_caching_test, bool use_micro_sdpa,
-                                                                   cldnn::layout input0_dyn_layout,
-                                                                   cldnn::layout input1_dyn_layout,
-                                                                   cldnn::layout input2_dyn_layout,
-                                                                   cldnn::layout input3_dyn_layout,
+                                                                   cldnn::layout input0_layout,
+                                                                   cldnn::layout input1_layout,
+                                                                   cldnn::layout input2_layout,
+                                                                   cldnn::layout input3_layout,
                                                                    cldnn::memory::ptr input0,
                                                                    cldnn::memory::ptr input1,
                                                                    cldnn::memory::ptr input2,
@@ -83,10 +86,10 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
                                                                    float attn_mask_val = 0.0f) {
        auto& engine = get_test_engine();
        topology topo;
-       topo.add(input_layout("input0", input0_dyn_layout));
-       topo.add(input_layout("input1", input1_dyn_layout));
-       topo.add(input_layout("input2", input2_dyn_layout));
-       topo.add(input_layout("input3", input3_dyn_layout));
+       topo.add(input_layout("input0", input0_layout));
+       topo.add(input_layout("input1", input1_layout));
+       topo.add(input_layout("input2", input2_layout));
+       topo.add(input_layout("input3", input3_layout));

        auto sdpa_prim = scaled_dot_product_attention("sdpa", {input_info("input0"), input_info("input1"), input_info("input2"), input_info("input3")},
                                                      false, -1, {0, 2, 1, 3}, {0, 2, 1, 3}, {0, 2, 1, 3}, {0, 1, 2, 3}, {}, false);
@@ -137,15 +140,30 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
        const auto attn_mask_val = p.attn_mask_val;

        auto& engine = get_test_engine();
-       cldnn::layout input0_dyn_layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
-       cldnn::layout input1_dyn_layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
-       cldnn::layout input2_dyn_layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
-       cldnn::layout input3_dyn_layout({-1, num_heads, -1, -1}, data_types::f16, format::bfyx);
-
-       cldnn::layout input0_static_layout({batch, seq_length_q, num_heads, head_size}, data_types::f16, format::bfyx);
-       cldnn::layout input1_static_layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
-       cldnn::layout input2_static_layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
-       cldnn::layout input3_static_layout({batch, num_heads, 1, seq_length_kv}, data_types::f16, format::bfyx);
+       cldnn::layout input0_layout, input1_layout, input2_layout, input3_layout;
+       cldnn::layout input0_static_layout, input1_static_layout, input2_static_layout, input3_static_layout;
+
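+       // choose the compile-time layouts: dynamic (-1 dims) or fully static shapes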
+       if (p.dynamic) {
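+           // leave batch and sequence dims unknown at compile time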
+           input0_layout = cldnn::layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
+           input1_layout = cldnn::layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
+           input2_layout = cldnn::layout({-1, -1, num_heads, head_size}, data_types::f16, format::bfyx);
+           input3_layout = cldnn::layout({-1, num_heads, -1, -1}, data_types::f16, format::bfyx);
+
+           input0_static_layout = cldnn::layout({batch, seq_length_q, num_heads, head_size}, data_types::f16, format::bfyx);
+           input1_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+           input2_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+           input3_static_layout = cldnn::layout({batch, num_heads, 1, seq_length_kv}, data_types::f16, format::bfyx);
+       } else {
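+           // compile directly with the fully specified static shapes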
+           input0_static_layout = cldnn::layout({batch, seq_length_q, num_heads, head_size}, data_types::f16, format::bfyx);
+           input1_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+           input2_static_layout = cldnn::layout({batch, seq_length_kv, num_heads, head_size}, data_types::f16, format::bfyx);
+           input3_static_layout = cldnn::layout({batch, num_heads, 1, seq_length_kv}, data_types::f16, format::bfyx);
+
+           input0_layout = input0_static_layout;
+           input1_layout = input1_static_layout;
+           input2_layout = input2_static_layout;
+           input3_layout = input3_static_layout;
+       }

        auto input0 = engine.allocate_memory(input0_static_layout);
        auto input1 = engine.allocate_memory(input1_static_layout);
@@ -158,11 +176,11 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
        load_input(input3, 3);

        auto [mem_ref_ptr, net_ref_ptr] = run_network(is_caching_test, false,
-           input0_dyn_layout, input1_dyn_layout, input2_dyn_layout, input3_dyn_layout,
+           input0_layout, input1_layout, input2_layout, input3_layout,
            input0, input1, input2, input3,
            use_scalar_scale_val, scale_val, use_scalar_attn_mask, attn_mask_val);
        auto [mem_opt_ptr, net_opt_ptr] = run_network(is_caching_test, true,
-           input0_dyn_layout, input1_dyn_layout, input2_dyn_layout, input3_dyn_layout,
+           input0_layout, input1_layout, input2_layout, input3_layout,
            input0, input1, input2, input3,
            use_scalar_scale_val, scale_val, use_scalar_attn_mask, attn_mask_val);
@@ -225,6 +243,10 @@ struct sdpa_gpu_test : public ::testing::TestWithParam<sdpa_test_params> {
            result += "_mask_" + std::to_string(static_cast<int>(info.param.attn_mask_val * 1000));
        }

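+       // tag static-shape variants so the generated test names stay unique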
+       if (!info.param.dynamic) {
+           result += "_static";
+       }
+
        return result;
    }
};
@@ -233,7 +255,10 @@ INSTANTIATE_TEST_SUITE_P(
    smoke_sdpa_gpu_test,
    sdpa_gpu_test,
    ::testing::Values(
-       sdpa_test_params{64, 32, 990, 128, 2},
+       sdpa_test_params{64, 32, 990, 128, 2, true},    // dynamic
+       sdpa_test_params{64, 32, 990, 128, 2, false},   // static
+       sdpa_test_params{64, 32, 990, 1, 2, true},      // dynamic
+       sdpa_test_params{64, 32, 990, 1, 2, false},     // static
        sdpa_test_params{64, 32, 128, 128, 2, true, 0.125f, false, 0.0f},  // scale_val only
        sdpa_test_params{64, 32, 128, 128, 2, false, 1.0f, true, 0.5f}     // attn_mask only
    ),