|
20 | 20 | #include "low_precision/layer_transformation.hpp" |
21 | 21 | #include "low_precision/quantization_details.hpp" |
22 | 22 | #include "nodes/fullyconnected.h" |
| 23 | +#include "nodes/gathermatmul.h" |
23 | 24 | #include "openvino/core/descriptor/tensor.hpp" |
24 | 25 | #include "openvino/core/graph_util.hpp" |
25 | 26 | #include "openvino/core/node.hpp" |
|
70 | 71 | #include "transformations/common_optimizations/mark_precision_sensitive_shapeof_subgraphs.hpp" |
71 | 72 | #include "transformations/common_optimizations/mark_rope_input_to_keep_in_mixed_precision.hpp" |
72 | 73 | #include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp" |
| 74 | +#include "transformations/common_optimizations/matmul_experts_fusion.hpp" |
73 | 75 | #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp" |
74 | 76 | #include "transformations/common_optimizations/mul_fake_quantize_fusion.hpp" |
75 | 77 | #include "transformations/common_optimizations/nop_elimination.hpp" |
|
78 | 80 | #include "transformations/common_optimizations/transpose_sinking.hpp" |
79 | 81 | #include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp" |
80 | 82 | #include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp" |
81 | | -#include "transformations/common_optimizations/matmul_experts_fusion.hpp" |
82 | 83 | #include "transformations/control_flow/unroll_tensor_iterator.hpp" |
83 | 84 | #include "transformations/convert_precision.hpp" |
| 85 | +#include "transformations/cpu_opset/common/op/batch_gather_matmul_compressed.hpp" |
84 | 86 | #include "transformations/fp16_compression/convert_compression_only_to_legacy.hpp" |
85 | 87 | #include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp" |
86 | 88 | #include "transformations/fp16_compression/mark_floatpoint_range.hpp" |
@@ -572,11 +574,14 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis |
572 | 574 | CPU_REGISTER_PASS_X64( |
573 | 575 | manager, |
574 | 576 | ConvertBatchGatherMatmulToBatchGatherMatmulCompressed, |
575 | | - // TODO: create separate helpers (defining supported precisions) for BatchGatherMatmul CPU node |
576 | | - ov::intel_cpu::node::FullyConnected::getSupportedCompressedActivationsTypes(), |
577 | | - ov::intel_cpu::node::FullyConnected::getSupportedCompressedWeightsTypes(), |
578 | | - // TODO: set a plugin configuration predicate when CPU node is implemented |
579 | | - nullptr); |
| 577 | + ov::intel_cpu::node::GatherMatmul::getSupportedCompressedActivationsTypes(), |
| 578 | + ov::intel_cpu::node::GatherMatmul::getSupportedCompressedWeightsTypes(), |
| 579 | + [&](const std::shared_ptr<ov::intel_cpu::BatchGatherMatmulCompressed>& gather_matmul, |
| 580 | + size_t IC, |
| 581 | + size_t OC, |
| 582 | + size_t G) { |
| 583 | + return ov::intel_cpu::node::GatherMatmul::isSupportedCompressedOperation(gather_matmul, IC, OC, G, config); |
| 584 | + }); |
580 | 585 | ov::pass::ConvertPagedAttnInputs::KVCacheConfig cacheConfig; |
581 | 586 | cacheConfig.keyCachePrecision = config.keyCachePrecision; |
582 | 587 | cacheConfig.valueCachePrecision = config.valueCachePrecision; |
|
0 commit comments