Skip to content

Commit 461df5a

Browse files
committed
Apply transformations ToDo
1 parent f7d654e commit 461df5a

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

src/plugins/intel_cpu/src/nodes/gathermatmul.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,13 @@ ov::element::TypeVector GatherMatmul::getSupportedCompressedWeightsTypes([[maybe
 #endif
 }
 
+ov::element::TypeVector GatherMatmul::getSupportedCompressedActivationsTypes() {
+    using ov::element::Type_t;
+    // @todo enable for bf16 as well
+    // after EnforceInferencePrecision is replaced with ConvertPrecision
+    return {Type_t::f32};
+}
+
 GatherMatmul::GatherMatmul(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context)
     : Node(op, context, GatherMatmulShapeInferFactory(op)) {
     std::string errorMessage;

src/plugins/intel_cpu/src/nodes/gathermatmul.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class GatherMatmul : public Node {
                          size_t G,
                          const Config& config) noexcept;
     static ov::element::TypeVector getSupportedCompressedWeightsTypes(bool apply_fp8 = false);
+    static ov::element::TypeVector getSupportedCompressedActivationsTypes();
 
 private:
     enum class Algorithm : uint8_t { GatherMatmulDefault, GatherMatmulCompressed };

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
 #include "low_precision/layer_transformation.hpp"
 #include "low_precision/quantization_details.hpp"
 #include "nodes/fullyconnected.h"
+#include "nodes/gathermatmul.h"
 #include "openvino/core/descriptor/tensor.hpp"
 #include "openvino/core/graph_util.hpp"
 #include "openvino/core/node.hpp"
@@ -70,6 +71,7 @@
 #include "transformations/common_optimizations/mark_precision_sensitive_shapeof_subgraphs.hpp"
 #include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp"
 #include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
+#include "transformations/common_optimizations/matmul_experts_fusion.hpp"
 #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp"
 #include "transformations/common_optimizations/mul_fake_quantize_fusion.hpp"
 #include "transformations/common_optimizations/nop_elimination.hpp"
@@ -78,9 +80,9 @@
 #include "transformations/common_optimizations/transpose_sinking.hpp"
 #include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"
 #include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp"
-#include "transformations/common_optimizations/matmul_experts_fusion.hpp"
 #include "transformations/control_flow/unroll_tensor_iterator.hpp"
 #include "transformations/convert_precision.hpp"
+#include "transformations/cpu_opset/common/op/batch_gather_matmul_compressed.hpp"
 #include "transformations/fp16_compression/convert_compression_only_to_legacy.hpp"
 #include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
 #include "transformations/fp16_compression/mark_floatpoint_range.hpp"
@@ -572,11 +574,14 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
     CPU_REGISTER_PASS_X64(
         manager,
         ConvertBatchGatherMatmulToBatchGatherMatmulCompressed,
-        // TODO: create separate helpers (defining supported precisions) for BatchGatherMatmul CPU node
-        ov::intel_cpu::node::FullyConnected::getSupportedCompressedActivationsTypes(),
-        ov::intel_cpu::node::FullyConnected::getSupportedCompressedWeightsTypes(),
-        // TODO: set a plugin configuration predicate when CPU node is implemented
-        nullptr);
+        ov::intel_cpu::node::GatherMatmul::getSupportedCompressedActivationsTypes(),
+        ov::intel_cpu::node::GatherMatmul::getSupportedCompressedWeightsTypes(),
+        [&](const std::shared_ptr<ov::intel_cpu::BatchGatherMatmulCompressed>& gather_matmul,
+            size_t IC,
+            size_t OC,
+            size_t G) {
+            return ov::intel_cpu::node::GatherMatmul::isSupportedCompressedOperation(gather_matmul, IC, OC, G, config);
+        });
     ov::pass::ConvertPagedAttnInputs::KVCacheConfig cacheConfig;
     cacheConfig.keyCachePrecision = config.keyCachePrecision;
     cacheConfig.valueCachePrecision = config.valueCachePrecision;

src/tests/functional/plugin/shared/include/shared_test_classes/subgraph/weights_decompression_builders.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
 #include <memory>
 #include "openvino/core/node.hpp"
 #include "shared_test_classes/base/ov_subgraph.hpp"
+#include <optional>
 
 namespace ov {
 namespace test {

0 commit comments

Comments (0)