Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "memory_desc/dnnl_memory_desc.h"
#include "nodes/executors/dnnl/dnnl_aliases.hpp"
#include "nodes/executors/dnnl/dnnl_fullyconnected_primitive.hpp"
#include "nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp"
#include "nodes/executors/executor.hpp"
#include "nodes/executors/fullyconnected_config.hpp"
Expand Down Expand Up @@ -306,7 +307,8 @@ static primitive_desc createPrimitiveDesc(const dnnl::memory::desc& inputDesc,
[[maybe_unused]] const bool useSparseWeights,
const bool useWeightsDecompression,
const bool fcSemantic) {
if (defaultImplType == impl_desc_type::undef) {
// priority-based implementation selection if implementation type is not specified
if (defaultImplType == impl_desc_type::undef && !fcSemantic) {
struct PrimitiveDescWithPriority {
dnnl::primitive_desc prim_desc;
size_t priority = 0UL;
Expand Down Expand Up @@ -484,28 +486,29 @@ bool DnnlMatMulPrimitive::useWeightsDecompressionImpl(const ov::element::Type in
return (any_of(inputType, f32, bf16, f16) && any_of(weightsType, u8, i8, u4, i4));
}

// Convenience overload: adapts fully-connected attributes (FCAttrs) into
// MatMulAttrs and forwards to the MatMulAttrs-based overload. Only the
// post-ops and the weights-transposition flag are carried over; all other
// MatMulAttrs fields keep their defaults.
DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const FCAttrs& fcAttrs,
                                                                      const MemoryArgs& memory,
                                                                      const ExecutorContext::CPtr& context,
                                                                      const bool cacheWeights) {
    MatMulAttrs matMulAttrs;
    matMulAttrs.weightsNonTransposed = fcAttrs.weightsNonTransposed;
    matMulAttrs.postOps = fcAttrs.postOps;

    // Delegate to the primary implementation taking MatMulAttrs.
    return createShapeAgnosticData(matMulAttrs, memory, context, cacheWeights);
}

DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const MatMulAttrs& attrs,
const MemoryArgs& memory,
const ExecutorContext::CPtr& context,
const bool cacheWeights) {
if (attrs.fcSemantic) {
FCAttrs fcAttrs{attrs.withBias,
attrs.weightsNonTransposed,
false,
0,
true,
ov::intel_cpu::Config::ModelType::Unknown,
attrs.postOps};
return DnnlFCPrimitive::createShapeAgnosticData(fcAttrs, memory, context, cacheWeights);
}

DEBUG_LOG("Creating shape agnostic data");
auto srcDesc = memory.at(ARG_SRC)->getDescPtr();
auto weiDesc = memory.at(ARG_WEI)->getDescPtr();
auto dstDesc = memory.at(ARG_DST)->getDescPtr();
const auto& biasDesc = memory.at(ARG_BIAS)->getDescPtr();

const auto useWeightsDecompression = useWeightsDecompressionImpl(srcDesc->getPrecision(), weiDesc->getPrecision());

const auto postOpData =
createPrimitiveAttrs(attrs, memory, context, useWeightsDecompression, attrs.weightsNonTransposed);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ class DnnlMatMulPrimitive {
const ExecutorContext::CPtr& context,
bool cacheWeights);

static DnnlShapeAgnosticDataPtr createShapeAgnosticData(const FCAttrs& fcAttrs,
const MemoryArgs& memory,
const ExecutorContext::CPtr& context,
bool cacheWeights);

static DnnlMemoryDescPtr makeTransposedWeightDescriptor(const DnnlMemoryDescPtr& srcDesc,
const DnnlMemoryDescPtr& dstDesc,
const MatMulAttrs& attrs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,18 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
[](const FCAttrs& attrs,
const MemoryArgs& memory,
const ExecutorContext::CPtr& context) -> ExecutorPtr {
MatMulAttrs matMulAttrs{false,
false};
matMulAttrs.postOps = attrs.postOps;
matMulAttrs.transposeB = attrs.weightsNonTransposed;
matMulAttrs.constantWeights = true;
MatMulAttrs matMulAttrs {
false,
false,
attrs.withBias,
attrs.weightsNonTransposed,
attrs.sparseWeights,
true,
true,
attrs.dynamicQuantizationGroupSize,
{},
attrs.postOps
};

return std::make_shared<
DnnlExecutor<DnnlMatMulPrimitive, MatMulAttrs, DnnlShapeAgnosticData,
Expand Down
Loading