47 changes: 47 additions & 0 deletions src/common/sdpa_test_iface.hpp
@@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright 2026 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_SDPA_TEST_IFACE_HPP
#define COMMON_SDPA_TEST_IFACE_HPP

#include "oneapi/dnnl/dnnl_types.h"

/// Creates a primitive descriptor for a scaled dot product attention
/// primitive.
///
/// @param primitive_desc_iface Output primitive descriptor.
/// @param engine Engine to use.
/// @param query_desc Query memory descriptor (tensor Q).
/// @param key_desc Key memory descriptor (tensor K).
/// @param value_desc Value memory descriptor (tensor V).
/// @param dst_desc Destination memory descriptor.
/// @param mask_desc Attention mask memory descriptor.
/// @param scale_desc Scale memory descriptor.
/// @param invert_scale If true, the scale is applied as a divisor rather
///     than a multiplier.
/// @param kv_head_number Number of key/value heads.
/// @param attn_mask_type Attention mask type.
/// @param softmax_alg Softmax algorithm kind.
/// @param attr Primitive attributes (can be NULL).
/// @param kq_attr Attributes for the Key/Query matmul operation (can be NULL).
/// @param vs_attr Attributes for the Value/Score matmul operation (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API sdpa_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc,
bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type,
dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr,
const_dnnl_primitive_attr_t kq_attr,
const_dnnl_primitive_attr_t vs_attr);

#endif
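
For reference (not part of the diff): a minimal sketch of how a test could exercise this interface end to end — create an engine and memory descriptors, create the SDPA primitive descriptor and primitive, then release everything. The shapes, data types, format tags, the attn_mask_type value of 0, and the availability of a GPU engine are placeholder assumptions; the real gtests derive these from each test case, and details such as whether the key tensor must be passed pre-transposed are not covered here.

#include <cstdio>
#include <initializer_list>

#include "common/sdpa_test_iface.hpp"
#include "oneapi/dnnl/dnnl.h"

int main() {
    // Placeholder problem: batch = 1, heads = 2, seq_len = 32, head_size = 64.
    const dnnl_dims_t qkv_dims = {1, 2, 32, 64};
    const dnnl_dims_t mask_dims = {1, 1, 32, 32};
    const dnnl_dims_t scale_dims = {1};

    dnnl_engine_t engine = nullptr;
    if (dnnl_engine_create(&engine, dnnl_gpu, 0) != dnnl_success) return 1;

    // Plain dense memory descriptors for Q, K, V, dst, mask, and scale.
    // Note: whether K must be passed pre-transposed is an implementation
    // detail of the internal SDPA primitive and is not covered by this sketch.
    dnnl_memory_desc_t md_q = nullptr, md_k = nullptr, md_v = nullptr,
            md_dst = nullptr, md_mask = nullptr, md_scale = nullptr;
    dnnl_memory_desc_create_with_tag(&md_q, 4, qkv_dims, dnnl_f16, dnnl_abcd);
    dnnl_memory_desc_create_with_tag(&md_k, 4, qkv_dims, dnnl_f16, dnnl_abcd);
    dnnl_memory_desc_create_with_tag(&md_v, 4, qkv_dims, dnnl_f16, dnnl_abcd);
    dnnl_memory_desc_create_with_tag(&md_dst, 4, qkv_dims, dnnl_f16, dnnl_abcd);
    dnnl_memory_desc_create_with_tag(&md_mask, 4, mask_dims, dnnl_f16, dnnl_abcd);
    dnnl_memory_desc_create_with_tag(&md_scale, 1, scale_dims, dnnl_f32, dnnl_a);

    // Create the SDPA primitive descriptor, then the primitive itself.
    dnnl_primitive_desc_t pd = nullptr;
    dnnl_primitive_t prim = nullptr;
    dnnl_status_t st = sdpa_primitive_desc_create(&pd, engine, md_q, md_k,
            md_v, md_dst, md_mask, md_scale, /*invert_scale=*/false,
            /*kv_head_number=*/2, /*attn_mask_type=*/0, dnnl_softmax_accurate,
            /*attr=*/nullptr, /*kq_attr=*/nullptr, /*vs_attr=*/nullptr);
    if (st == dnnl_success) st = dnnl_primitive_create(&prim, pd);
    std::printf("sdpa creation: %s\n", st == dnnl_success ? "ok" : "failed");

    // Release everything that was created.
    if (prim) dnnl_primitive_destroy(prim);
    if (pd) dnnl_primitive_desc_destroy(pd);
    for (auto md : {md_q, md_k, md_v, md_dst, md_mask, md_scale})
        if (md) dnnl_memory_desc_destroy(md);
    dnnl_engine_destroy(engine);
    return st == dnnl_success ? 0 : 1;
}
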
215 changes: 38 additions & 177 deletions src/graph/backend/dnnl/executables/sdpa.cpp
@@ -16,6 +16,8 @@

#include "graph/backend/dnnl/executables/sdpa.hpp"

#include "common/sdpa_test_iface.hpp"

namespace dnnl {
namespace impl {
namespace graph {
@@ -73,134 +75,48 @@ sdpa_executable_t::sdpa_executable_t(std::shared_ptr<op_t> &op,
const alg_kind_t softmax_alg = softmax_mode == "inf_as_zero"
? alg_kind::softmax_accurate_inf_as_zero
: alg_kind::softmax_accurate;
status_t s = create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(),

auto ret = sdpa_primitive_desc_create(&pd_, p_engine.get(), md_q.get(),
md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), md_scale.get(),
is_invert_scale_, kv_head_number, mask_type_, softmax_alg,
attr.get(), qk_attr.get(), vs_attr.get());
if (s != dnnl::impl::status::success) {
is_invert_scale_, kv_head_number, mask_type_,
static_cast<dnnl_alg_kind_t>(softmax_alg), attr.get(),
qk_attr.get(), vs_attr.get());

if (ret != dnnl_success) {
is_initialized_ = false;
} else {
status_t s = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
is_initialized_ = s == status::success ? true : false;
ret = dnnl_primitive_create(&prim_, pd_);
is_initialized_ = ret == dnnl_success ? true : false;
}
}

sdpa_executable_t::~sdpa_executable_t() {
if (prim_) dnnl_primitive_destroy(prim_);
if (pd_) dnnl_primitive_desc_destroy(pd_);
}

void sdpa_executable_t::execute(const stream &stream,
const std::unordered_map<int, memory> &args) const {
exec_args_t exec_args;
memory_arg_t mem_arg_q = {(args.at(DNNL_ARG_QUERIES)).get(), true};
memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true};
memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true};
memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false};
memory_arg_t mem_arg_scale
= {with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true};
memory_arg_t mem_arg_mask = {
with_explicit_mask_ ? (args.at(DNNL_ARG_ATTN_MASK)).get() : nullptr,
true};
memory_arg_t mem_arg_k_scale = {
args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS) != args.end()
? (args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS)).get()
: nullptr,
true};

memory_arg_t mem_arg_v_scale = {
args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES) != args.end()
? (args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES)).get()
: nullptr,
true};
memory_arg_t mem_arg_k_zero_points = {
args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS) != args.end()
? (args.at(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS)).get()
: nullptr,
true};
memory_arg_t mem_arg_v_zero_points = {
args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES) != args.end()
? (args.at(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES))
.get()
: nullptr,
true};

exec_args[DNNL_ARG_QUERIES] = mem_arg_q;
exec_args[DNNL_ARG_KEYS] = mem_arg_k;
exec_args[DNNL_ARG_VALUES] = mem_arg_v;
exec_args[DNNL_ARG_DST] = mem_arg_dst;
exec_args[DNNL_ARG_SCALE] = mem_arg_scale;
exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask;
exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale;
exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES] = mem_arg_v_scale;
exec_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS]
= mem_arg_k_zero_points;
exec_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES]
= mem_arg_v_zero_points;

exec_ctx_t ctx(stream.get(), std::move(exec_args));
sdpa_prim_->execute(ctx);
UNUSED(stream);
UNUSED(args);
assert(!"sdpa_executable_t::execute() is not implemented on cpu");
}

#ifdef DNNL_WITH_SYCL
::sycl::event sdpa_executable_t::execute_sycl(const stream &stream,
const std::unordered_map<int, memory> &args,
const std::vector<::sycl::event> &deps) const {
std::vector<dnnl_exec_arg_t> c_args;
c_args.reserve(args.size());
for (const auto &a : args)
c_args.push_back({a.first, a.second.get()});

exec_args_t exec_args;
memory_arg_t mem_arg_q = {(args.at(DNNL_ARG_QUERIES)).get(), true};
memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true};
memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true};
memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false};
memory_arg_t mem_arg_scale
= {with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true};
memory_arg_t mem_arg_mask = {
with_explicit_mask_ ? (args.at(DNNL_ARG_ATTN_MASK)).get() : nullptr,
true};
memory_arg_t mem_arg_k_scale = {
args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS) != args.end()
? (args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS)).get()
: nullptr,
true};

memory_arg_t mem_arg_v_scale = {
args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES) != args.end()
? (args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES)).get()
: nullptr,
true};
memory_arg_t mem_arg_k_zero_points = {
args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS) != args.end()
? (args.at(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS)).get()
: nullptr,
true};
memory_arg_t mem_arg_v_zero_points = {
args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES) != args.end()
? (args.at(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES))
.get()
: nullptr,
true};

exec_args[DNNL_ARG_QUERIES] = mem_arg_q;
exec_args[DNNL_ARG_KEYS] = mem_arg_k;
exec_args[DNNL_ARG_VALUES] = mem_arg_v;
exec_args[DNNL_ARG_DST] = mem_arg_dst;
exec_args[DNNL_ARG_SCALE] = mem_arg_scale;
exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask;
exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale;
exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES] = mem_arg_v_scale;
exec_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS]
= mem_arg_k_zero_points;
exec_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES]
= mem_arg_v_zero_points;

auto strm_t = stream.get();
exec_ctx_t ctx(strm_t, std::move(exec_args));
auto *sycl_stream_impl = dnnl::impl::utils::downcast<
dnnl::impl::xpu::sycl::stream_impl_t *>(strm_t->impl());
sycl::event return_event;
auto ret = dnnl_sycl_interop_primitive_execute(prim_, stream.get(),
c_args.size(), c_args.data(), &deps, &return_event);
dnnl::error::wrap_c_api(
ret, "could not execute sdpa primitive with sycl runtime");

strm_t->before_exec_hook();

if (!deps.empty()) sycl_stream_impl->sycl_ctx().set_deps(deps);

sdpa_prim_->execute(ctx);

::sycl::event return_event = sycl_stream_impl->get_output_event();
strm_t->after_exec_hook();
return return_event;
}
#endif
@@ -209,75 +125,20 @@ ::sycl::event sdpa_executable_t::execute_sycl(const stream &stream,
cl_event sdpa_executable_t::execute_ocl(const stream &stream,
const std::unordered_map<int, memory> &args,
const std::vector<cl_event> &deps) const {
exec_args_t exec_args;
memory_arg_t mem_arg_q = {(args.at(DNNL_ARG_QUERIES)).get(), true};
memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true};
memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true};
memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false};
memory_arg_t mem_arg_scale
= {with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true};
memory_arg_t mem_arg_mask = {
with_explicit_mask_ ? (args.at(DNNL_ARG_ATTN_MASK)).get() : nullptr,
true};
memory_arg_t mem_arg_k_scale = {
args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS) != args.end()
? (args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS)).get()
: nullptr,
true};

memory_arg_t mem_arg_v_scale = {
args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES) != args.end()
? (args.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES)).get()
: nullptr,
true};
memory_arg_t mem_arg_k_zero_points = {
args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS) != args.end()
? (args.at(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS)).get()
: nullptr,
true};
memory_arg_t mem_arg_v_zero_points = {
args.find(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES) != args.end()
? (args.at(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES))
.get()
: nullptr,
true};

exec_args[DNNL_ARG_QUERIES] = mem_arg_q;
exec_args[DNNL_ARG_KEYS] = mem_arg_k;
exec_args[DNNL_ARG_VALUES] = mem_arg_v;
exec_args[DNNL_ARG_DST] = mem_arg_dst;
exec_args[DNNL_ARG_SCALE] = mem_arg_scale;
exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask;
exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale;
exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES] = mem_arg_v_scale;
exec_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS]
= mem_arg_k_zero_points;
exec_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES]
= mem_arg_v_zero_points;

exec_ctx_t ctx(stream.get(), std::move(exec_args));
std::vector<dnnl_exec_arg_t> c_args;
c_args.reserve(args.size());
for (const auto &a : args)
c_args.push_back({a.first, a.second.get()});

auto *ocl_stream = dnnl::impl::utils::downcast<gpu::intel::ocl::stream_t *>(
stream.get());

ocl_stream->before_exec_hook();

if (!deps.empty()) {
std::vector<xpu::ocl::wrapper_t<cl_event>> events(deps.size());
for (size_t i = 0; i < deps.size(); i++)
events[i] = xpu::ocl::wrapper_t<cl_event>(deps[i], true);
ocl_stream->ocl_ctx().set_deps(events);
}

sdpa_prim_->execute(ctx);
const cl_event *c_deps = deps.empty() ? nullptr : deps.data();

cl_event return_event = nullptr;
if ((ocl_stream->flags() & stream_flags::in_order) == 0) {
auto last = ocl_stream->get_output_event();
return_event = last.release();
}
auto ret = dnnl_ocl_interop_primitive_execute(prim_, stream.get(),
static_cast<int>(c_args.size()), c_args.data(), c_deps,
static_cast<int>(deps.size()), &return_event);
dnnl::error::wrap_c_api(
ret, "could not execute sdpa primitive with ocl runtime");

ocl_stream->after_exec_hook();
return return_event;
}
#endif
6 changes: 4 additions & 2 deletions src/graph/backend/dnnl/executables/sdpa.hpp
@@ -33,6 +33,8 @@ struct sdpa_executable_t : public op_executable_t {
pd_cache_t &pd_cache, const fpmath_t &fpmath,
bool use_block_layout);

~sdpa_executable_t() override;

bool is_initialized() const { return is_initialized_; }

void execute(const stream &stream,
@@ -51,8 +53,8 @@
#endif

private:
std::shared_ptr<primitive_desc_t> sdpa_pd_;
std::shared_ptr<primitive_t> sdpa_prim_;
dnnl_primitive_desc_t pd_ = nullptr;
dnnl_primitive_t prim_ = nullptr;
bool with_scale_;
bool with_explicit_mask_;
attn_mask_type_t mask_type_;
27 changes: 2 additions & 25 deletions tests/gtests/internals/sdpa_internal.hpp
@@ -19,32 +19,9 @@

#include "dnnl.hpp"

// NOLINTBEGIN(readability-identifier-naming)

/// Creates a primitive descriptor for a scaled dot product attention primitive
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param query_desc Query memory descriptor (tensor Q)
/// @param key_desc Key memory descriptor (tensor K)
/// @param value_desc Value memory descriptor (tensor V)
/// @param dst_desc Destination memory descriptor.
/// @param attn_mask_desc Attention mask memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param kq_attr Attribute for the Key/Query matmul operation(can be NULL).
/// @param vs_attr Attribute for the Value/Score matmul operation(can be NULL).
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
#include "common/sdpa_test_iface.hpp"

dnnl_status_t DNNL_API sdpa_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc,
bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type,
dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr,
const_dnnl_primitive_attr_t kq_attr,
const_dnnl_primitive_attr_t vs_attr);
// NOLINTBEGIN(readability-identifier-naming)

namespace dnnl {
namespace impl {