@@ -404,14 +404,14 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
#if WITH_ATTN_MASK
/* Load mask. No remainder handling needed assuming k block size is a power of 2. */
mask_tile_type mask_tile;
-    const uint mask_m = MSK_D2;
-    const uint mask_n = MSK_D3;
-    // Check if attention mask has a single Query dimension (e.g., [batch, num_heads, 1, sequence_length])
-    // In the case of single query dimension, set ld and offset_r to zero
-    // to avoid exceeding bounds for single dimension.
-    const uint mask_ld = (mask_m == 1)? 0 : mask_n;
-    const uint mask_offset_r = (mask_m == 1)? 0 : sg_j0_kq + wg_j0;
-    tile_load_t(&mask_tile, msk, mask_m, mask_n, mask_ld, mask_offset_r, k0 + sg_i0_kq);
+    if (MSK_D2 == 1 && MSK_D3 > 1) {
+        // Check if attention mask has a single Query dimension (e.g., [batch, num_heads, 1, sequence_length])
+        // In the case of single query dimension, set ld and offset_r to zero
+        // to avoid exceeding bounds for single dimension.
+        tile_load_t(&mask_tile, msk, MSK_D2, MSK_D3, 0, 0, k0 + sg_i0_kq);
+    } else {
+        tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
+    }
#endif

#if REMAINDER_K
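In the new code path above, a mask with a single Query row (MSK_D2 == 1, i.e. a [batch, num_heads, 1, sequence_length] mask) is loaded with both the leading dimension and the row offset forced to zero, so every Query row of the tile re-reads the same mask row instead of stepping past the end of the buffer; the else branch keeps the original behaviour and offsets the load by the actual Query row (sg_j0_kq + wg_j0). A minimal sketch of that broadcast, assuming a simplified row-major tile load with a (rows, cols, ld, offset_r, offset_c) signature (the real tile_load_t in the micro-kernel headers takes different parameters):

#include <cstdio>

// Sketch only: element (r, c) of the tile is read from
// src[(offset_r + r) * ld + offset_c + c]. With ld == 0 and offset_r == 0,
// every tile row reads the same source row.
static void load_tile(const float* src, float* tile, int rows, int cols,
                      int ld, int offset_r, int offset_c) {
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            tile[r * cols + c] = src[(offset_r + r) * ld + offset_c + c];
}

int main() {
    float mask[8] = {0, 0, 0, 0, -1e9f, -1e9f, -1e9f, -1e9f};  // one [1 x 8] mask row
    float tile[4 * 4];
    // ld = 0, offset_r = 0: the single mask row is broadcast to all 4 tile rows.
    load_tile(mask, tile, 4, 4, /*ld=*/0, /*offset_r=*/0, /*offset_c=*/4);
    for (int r = 0; r < 4; ++r)
        printf("%g %g %g %g\n", tile[r * 4], tile[r * 4 + 1], tile[r * 4 + 2], tile[r * 4 + 3]);
    return 0;
}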
@@ -3,6 +3,7 @@
//

#include "intel_gpu/plugin/program_builder.hpp"
#include "intel_gpu/plugin/common_utils.hpp"

#include "intel_gpu/op/sdpa.hpp"
#include "intel_gpu/op/indirect_sdpa.hpp"
@@ -11,6 +12,7 @@
#include "openvino/op/scaled_dot_product_attention.hpp"

#include "intel_gpu/primitives/scaled_dot_product_attention.hpp"
#include "intel_gpu/primitives/reshape.hpp"

namespace ov {
namespace op {
@@ -34,6 +36,45 @@ static std::shared_ptr<ov::op::v0::Constant> GetScalarConstInput(const std::shar
return constOp;
}

static void ReshapeInput(ProgramBuilder& p, const std::shared_ptr<ov::op::Op>& op, std::vector<cldnn::input_info>& inputs) {
if (!p.use_new_shape_infer()) {
auto layer_name = layer_type_name_ID(op);
auto output_pshape = op->get_output_partial_shape(0);
auto output_rank = output_pshape.size() < 4 ? 4 : output_pshape.size();

for (size_t idx = 0; idx < op->get_input_size(); ++idx) {
if (op->get_input_partial_shape(idx).rank().get_length() < 4) {
auto &input = inputs[idx];
auto input_pshape = op->get_input_partial_shape(idx);
auto input_rank = input_pshape.size();

auto input_shape = op->get_input_shape(idx);
input_shape.insert(input_shape.begin(), output_rank - input_rank, 1ul);

auto target_input_shape = tensor_from_dims(input_shape);
auto input_reshape_name = layer_name + "_input_" + std::to_string(idx) + "_reshape";
auto input_reshape_prim = cldnn::reshape(input_reshape_name, input, target_input_shape);
p.add_primitive(*op, input_reshape_prim);
input.pid = input_reshape_name;
}
}
}
}

static void GetNewOrder(ProgramBuilder& p, const std::shared_ptr<ov::op::internal::SDPA>& op, std::vector<std::vector<int64_t>>& transpose_orders) {
transpose_orders[0] = op->get_input0_transpose_order();
transpose_orders[1] = op->get_input1_transpose_order();
transpose_orders[2] = op->get_input2_transpose_order();
transpose_orders[3] = op->get_output_transpose_order();

if (!p.use_new_shape_infer() && op->get_input_partial_shape(0).rank().get_length() < 4) {
for (auto &order : transpose_orders) {
for (auto &dim : order) ++dim;
order.insert(order.begin(), 0);
}
}
}

static void CreateScaledDotProductAttentionOp(ProgramBuilder& p, const std::shared_ptr<ov::op::v13::ScaledDotProductAttention>& op) {
// if transpose fusion is disabled, this is used
validate_inputs_count(op, {3, 4, 5});
@@ -43,6 +84,8 @@ static void CreateScaledDotProductAttentionOp(ProgramBuilder& p, const std::shar
auto scalar_scale = GetScalarConstInput(op, scale_idx);
auto scalar_attn_mask = GetScalarConstInput(op, attn_mask_idx);

+    ReshapeInput(p, op, inputs);
+
bool is_causal = op->get_causal();
auto order = ov::op::internal::SDPA::default_order(op->get_output_partial_shape(0).size());
auto sdpa_prim = cldnn::scaled_dot_product_attention(layerName,
@@ -73,17 +116,22 @@ static void CreateSDPAOp(ProgramBuilder& p, const std::shared_ptr<ov::op::intern
auto scalar_scale = GetScalarConstInput(op, scale_idx);
auto scalar_attn_mask = GetScalarConstInput(op, attn_mask_idx);

+    ReshapeInput(p, op, inputs);
+
+    std::vector<std::vector<int64_t>> transpose_orders(4);
+    GetNewOrder(p, op, transpose_orders);
+
bool is_causal = op->get_causal();
int64_t indirect_axis = -1;

auto sdpa_prim = cldnn::scaled_dot_product_attention(layerName,
inputs,
is_causal,
indirect_axis,
-                                                          op->get_input0_transpose_order(),
-                                                          op->get_input1_transpose_order(),
-                                                          op->get_input2_transpose_order(),
-                                                          op->get_output_transpose_order());
+                                                          transpose_orders[0],
+                                                          transpose_orders[1],
+                                                          transpose_orders[2],
+                                                          transpose_orders[3]);
if (scalar_scale) {
sdpa_prim.scale_val = scalar_scale->cast_vector<float>()[0];
}
@@ -102,6 +150,11 @@ static void CreateIndirectSDPAOp(ProgramBuilder& p, const std::shared_ptr<ov::op
auto scalar_scale = GetScalarConstInput(op, scale_idx);
auto scalar_attn_mask = GetScalarConstInput(op, attn_mask_idx);

+    ReshapeInput(p, op, inputs);
+
+    std::vector<std::vector<int64_t>> transpose_orders(4);
+    GetNewOrder(p, op, transpose_orders);
+
bool is_causal = op->get_causal();
const auto compression_inputs = op->get_compression_inputs_num();
validate_inputs_count(op, {4 + compression_inputs, 5 + compression_inputs, 6 + compression_inputs});
@@ -111,10 +164,10 @@ static void CreateIndirectSDPAOp(ProgramBuilder& p, const std::shared_ptr<ov::op
inputs,
is_causal,
indirect_axis,
-                                                                   op->get_input0_transpose_order(),
-                                                                   op->get_input1_transpose_order(),
-                                                                   op->get_input2_transpose_order(),
-                                                                   op->get_output_transpose_order(),
+                                                                   transpose_orders[0],
+                                                                   transpose_orders[1],
+                                                                   transpose_orders[2],
+                                                                   transpose_orders[3],
op->get_quantization_attrs(),
op->get_kv_compressed());
if (scalar_scale) {
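The two helpers added above serve the legacy path where the new shape inference is disabled and cldnn layouts are at least 4-D: ReshapeInput inserts a cldnn::reshape primitive that prepends 1s to any SDPA input of rank below 4, and GetNewOrder shifts the stored transpose orders so they address the padded rank. A minimal standalone sketch of the two computations (pad_to_rank and shift_order are illustrative names, not part of the plugin):

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors ReshapeInput's shape padding: prepend 1s until the input reaches the
// (at least 4-D) output rank, e.g. [16, 128, 80] -> [1, 16, 128, 80].
static std::vector<size_t> pad_to_rank(std::vector<size_t> shape, size_t target_rank) {
    if (shape.size() < target_rank)
        shape.insert(shape.begin(), target_rank - shape.size(), 1ul);
    return shape;
}

// Mirrors GetNewOrder's adjustment: every axis index moves up by one and a new
// leading axis 0 is prepended, e.g. a 3-D order [0, 2, 1] becomes [0, 1, 3, 2].
static std::vector<int64_t> shift_order(std::vector<int64_t> order) {
    for (auto& dim : order)
        ++dim;
    order.insert(order.begin(), 0);
    return order;
}

int main() {
    for (auto d : pad_to_rank({16, 128, 80}, 4))
        std::cout << d << ' ';   // prints: 1 16 128 80
    std::cout << '\n';
    for (auto d : shift_order({0, 2, 1}))
        std::cout << d << ' ';   // prints: 0 1 3 2
    std::cout << '\n';
    return 0;
}

With this in place, CreateSDPAOp and CreateIndirectSDPAOp pass the shifted orders (transpose_orders[0..3]) to the cldnn::scaled_dot_product_attention primitive instead of the op's original orders whenever the inputs had to be padded.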
@@ -325,6 +325,42 @@ INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttnDynamic3D_GPU,
dynamic_shape_params_3D,
ScaledAttnLayerGPUTest::getTestCaseName);

const std::vector<std::vector<InputShape>> static_shapes_3D{
// static shapes
{
// q shape
{ov::test::InputShape{ov::PartialShape{16, 128, 80},
{ov::Shape{16, 128, 80}}}
},
// k shape
{ov::test::InputShape{ov::PartialShape{16, 128, 80},
{ov::Shape{16, 128, 80}}}
},
// v shape
{ov::test::InputShape{ov::PartialShape{16, 128, 80},
{ov::Shape{16, 128, 80}}}
},
// attn shape: [B, 1, -1, L0+L1]
{ov::test::InputShape{ov::PartialShape{128, 128},
{ov::Shape{128, 128}}}
},
},
};

const auto static_shape_params_3D = testing::Combine(testing::Values(ov::element::f16),
testing::ValuesIn(static_shapes_3D),
testing::Values(true, false),
testing::Values(true, false),
testing::Values(true, false),
testing::Values(true, false),
testing::Values(true, false),
testing::ValuesIn({disable_transpose, transpose_all_3D}));

INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttnStatic3D_GPU,
ScaledAttnLayerGPUTest,
static_shape_params_3D,
ScaledAttnLayerGPUTest::getTestCaseName);

const std::vector<std::vector<InputShape>> dynamic_shapes_4D {
// normal case, shapes of q,k,v are same
{
@@ -549,6 +585,24 @@ const std::vector<std::vector<InputShape>> dynamic_shapes_4D {
{ov::test::InputShape{ov::PartialShape{-1, -1},
{ov::Shape{245, 245}, ov::Shape{1, 1}, ov::Shape{10, 10}}}
},
},
{
// q shape
{ov::test::InputShape{ov::PartialShape{-1, 10, -1, 64},
{ov::Shape{1, 10, 77, 64}}}
},
// k shape
{ov::test::InputShape{ov::PartialShape{-1, 10, -1, 64},
{ov::Shape{1, 10, 77, 64}}}
},
// v shape
{ov::test::InputShape{ov::PartialShape{-1, 10, -1, 64},
{ov::Shape{1, 10, 77, 64}}}
},
// attn shape: [B, 1, -1, L0+L1]
{ov::test::InputShape{ov::PartialShape{77, 77},
{ov::Shape{77, 77}}}
},
}
};

@@ -643,6 +697,24 @@ const std::vector<std::vector<InputShape>> static_shapes{
{ov::Shape{1, 1, 1, 100}}}
},
},
{
// q shape
{ov::test::InputShape{ov::PartialShape{1, 10, 77, 64},
{ov::Shape{1, 10, 77, 64}}}
},
// k shape
{ov::test::InputShape{ov::PartialShape{1, 10, 77, 64},
{ov::Shape{1, 10, 77, 64}}}
},
// v shape
{ov::test::InputShape{ov::PartialShape{1, 10, 77, 64},
{ov::Shape{1, 10, 77, 64}}}
},
// attn shape: [B, 1, -1, L0+L1]
{ov::test::InputShape{ov::PartialShape{77, 77},
{ov::Shape{77, 77}}}
},
},
};

const auto static_shape_params = testing::Combine(testing::Values(ov::element::f16),
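The new functional cases pair reduced-rank inputs with the plugin-side padding: a fully 3-D static case (Q/K/V of {16, 128, 80} with a {128, 128} mask) and 4-D cases whose attention mask is only 2-D ({77, 77}), all of which go through the reshape path added in the plugin code above when the new shape inference is off. As a rough illustration of the subgraph such a parameter set describes (make_3d_sdpa_model is a hypothetical helper; the real model is built inside ScaledAttnLayerGPUTest):

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"

// Rank-3 Q/K/V of [16, 128, 80] and a rank-2 attention mask of [128, 128],
// matching the static_shapes_3D entry above.
std::shared_ptr<ov::Model> make_3d_sdpa_model() {
    auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{16, 128, 80});
    auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{16, 128, 80});
    auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{16, 128, 80});
    auto mask = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{128, 128});
    auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask, /*causal=*/false);
    return std::make_shared<ov::Model>(ov::OutputVector{sdpa}, ov::ParameterVector{q, k, v, mask});
}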
10 changes: 10 additions & 0 deletions src/plugins/intel_gpu/tests/functional/subgraph_tests/sdpa.cpp
Original file line number Diff line number Diff line change
@@ -249,6 +249,16 @@ INSTANTIATE_TEST_SUITE_P(SDPAFusionTests,
1.0f,
0.025f,
0.025f),
std::make_tuple(ov::PartialShape{1, 10, 77, 64},
ov::Shape{10, 77, 64},
ov::PartialShape{1, 10, 77, 64},
ov::Shape{10, 77, 64},
ov::PartialShape{1, 10, 77, 64},
ov::Shape{10, 77, 64},
ov::PartialShape{77, 77},
1.0f,
0.025f,
0.025f),
std::make_tuple(ov::PartialShape{1, 10, 1024, 64},
ov::Shape{1, 10, 1024, 64},
ov::PartialShape{1, 10, 1024, 64},