diff --git a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp index 1cf2e6d0230a9a..88e8905e37b630 100644 --- a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp +++ b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp @@ -210,10 +210,23 @@ bool ov::pass::SymbolicOptimizations::run_on_model(const std::shared_ptrget_pass_config(); + + const auto old_pass_config = *pass_config; + + // Temporarily disable passes for SymbolicOptimizations execution pass_config->disable(); pass_config->disable(); - m_manager->run_passes(m); + bool result; + try { + result = m_manager->run_passes(m); + } catch (...) { + *pass_config = old_pass_config; // Restore original pass config on exception + throw; // Re-throw the exception + } + + *pass_config = old_pass_config; + ov::remove_skip_invalidation_rti(m); - return true; + return result; } diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index 9f3c2df7fb8eed..4eb8665df825ee 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -10,10 +10,6 @@ #include #include #include -#include -#include -#include -#include #include #include @@ -23,7 +19,9 @@ #include "openvino/core/node.hpp" #include "openvino/core/node_output.hpp" #include "openvino/core/node_vector.hpp" +#include "openvino/core/rt_info.hpp" #include "openvino/core/type.hpp" +#include "openvino/core/type/element_type.hpp" #include "openvino/op/assign.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" @@ -36,9 +34,12 @@ #include "openvino/pass/matcher_pass.hpp" #include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/common_optimizations/simplify_shape_of_sub_graph.hpp" #include "transformations/cpu_opset/common/op/sdpa.hpp" #include "transformations/defs.hpp" +#include "transformations/symbolic_transformations/symbolic_optimizations.hpp" #include "transformations/transpose_sinking/ts_shape_of.hpp" #include "transformations/utils/utils.hpp" @@ -46,7 +47,6 @@ # include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp" #endif -using namespace ov::gen_pattern; using namespace ov::pass; namespace ov::intel_cpu { @@ -55,28 +55,24 @@ StatefulSDPAFusion::StatefulSDPAFusion() { MATCHER_SCOPE(StatefulSDPAFusion); using namespace ov::pass::pattern; - auto beam_idx = makePattern("i32[?]"); + auto beam_idx = any_input(type_matches(element::i32) && shape_matches("[?]")); auto cur_q = any_input(); auto cur_k = any_input(); auto cur_v = any_input(); - auto axis_seq_len = ov::gen_pattern::Symbol("axis_seq_len"); - auto axis_beam = ov::gen_pattern::Symbol("axis_beam"); + auto past_k = wrap_type(); + auto past_v = wrap_type(); - // past_kv can be BHLS/LBHS - auto past_k = makePattern({}); - auto past_v = makePattern({}); - - auto convert_past_k = makePattern({past_k}); - auto convert_past_v = makePattern({past_v}); + auto convert_past_k = wrap_type({past_k}); + auto convert_past_v = wrap_type({past_v}); auto gather_input_k = - makePattern({past_k | convert_past_k, beam_idx, axis_beam}, {{"batch_dims", 0}}); + wrap_type({past_k | convert_past_k, beam_idx, "axis_beam"}, {{"batch_dims", 0}}); auto gather_input_v = - makePattern({past_v | convert_past_v, beam_idx, axis_beam}, {{"batch_dims", 0}}); + wrap_type({past_v | convert_past_v, beam_idx, "axis_beam"}, {{"batch_dims", 0}}); - auto concat_k = makePattern({gather_input_k, cur_k}, {{"axis", axis_seq_len}}); - auto concat_v = makePattern({gather_input_v, cur_v}, {{"axis", axis_seq_len}}); + auto concat_k = wrap_type({gather_input_k, cur_k}); + auto concat_v = wrap_type({gather_input_v, cur_v}); std::shared_ptr mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k; std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k) = @@ -88,23 +84,23 @@ StatefulSDPAFusion::StatefulSDPAFusion() { auto present_v = concat_v | mq_reshape_v; // canonical q/k/v shape definition: [B,H,...L,S] - auto sdp0 = makePattern({cur_q, present_k, present_v}); - auto sdp1 = makePattern({cur_q, present_k, present_v, any_input()}); + auto sdp0 = wrap_type({cur_q, present_k, present_v}); + auto sdp1 = wrap_type({cur_q, present_k, present_v, any_input()}); auto sdp2 = - makePattern({cur_q, present_k, present_v, any_input(), any_input()}); + wrap_type({cur_q, present_k, present_v, any_input(), any_input()}); // non-canonical q/k/v shape definitions, for example: [L, B, H, S]/[B, L, H, S] auto order_k = wrap_type(); auto order_v = wrap_type(); auto order_q = wrap_type(); - auto transpose_q = makePattern({cur_q, order_q}); - auto transpose_k = makePattern({present_k, order_k}); - auto transpose_v = makePattern({present_v, order_v}); + auto transpose_q = wrap_type({cur_q, order_q}); + auto transpose_k = wrap_type({present_k, order_k}); + auto transpose_v = wrap_type({present_v, order_v}); - auto sdp_trans0 = makePattern({transpose_q, transpose_k, transpose_v}); + auto sdp_trans0 = wrap_type({transpose_q, transpose_k, transpose_v}); auto sdp_trans1 = - makePattern({transpose_q, transpose_k, transpose_v, any_input()}); - auto sdp_trans2 = makePattern( + wrap_type({transpose_q, transpose_k, transpose_v, any_input()}); + auto sdp_trans2 = wrap_type( {transpose_q, transpose_k, transpose_v, any_input(), any_input()}); auto sdp = sdp0 | sdp1 | sdp2 | sdp_trans0 | sdp_trans1 | sdp_trans2; @@ -112,8 +108,11 @@ StatefulSDPAFusion::StatefulSDPAFusion() { ov::matcher_pass_callback callback = [=](Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); auto root = m.get_match_root(); - PatternValidator validator(m); - if (!validator) { + + // Check concat axes equality first + const auto concat_k_node = ov::as_type_ptr(pattern_map.at(concat_k).get_node_shared_ptr()); + const auto concat_v_node = ov::as_type_ptr(pattern_map.at(concat_v).get_node_shared_ptr()); + if (concat_k_node->get_axis() != concat_v_node->get_axis()) { return false; } @@ -155,9 +154,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() { if (!check_valid_children_type(past_k_node) || !check_valid_children_type(past_v_node)) { return false; } - const auto concat_k_node = ov::as_type_ptr(pattern_map.at(concat_k).get_node_shared_ptr()); - const auto concat_v_node = ov::as_type_ptr(pattern_map.at(concat_v).get_node_shared_ptr()); - for (auto&& item : {concat_k_node, concat_v_node}) { auto&& children = item->get_output_target_inputs(0); switch (children.size()) { @@ -301,16 +297,16 @@ StatefulSDPAFusion::StatefulSDPAFusion() { bool SDPASubgraphFusion::run_on_model(const std::shared_ptr& f) { RUN_ON_FUNCTION_SCOPE(SDPASubgraphFusion); - ov::pass::Manager manager("SDPASubgraphFusion"); + ov::pass::SymbolicOptimizations symbolic_optimizations(false, get_pass_config()); + auto& ctx_manager = *symbolic_optimizations.get_manager(); - CPU_REGISTER_PASS_COMMON(manager, ov::pass::SimplifyGatherShapeOf); - CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward); - CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion); + CPU_REGISTER_PASS_COMMON(ctx_manager, ov::pass::SimplifyGatherShapeOf); + CPU_REGISTER_PASS_COMMON(ctx_manager, ov::pass::transpose_sinking::TSShapeOfForward); + CPU_REGISTER_PASS_COMMON(ctx_manager, StatefulSDPAFusion); // TODO: remove the following after snippets support patterns with dynamic shapes - CPU_REGISTER_PASS_X64(manager, ov::intel_cpu::SDPAFuseTransposeReshape); + CPU_REGISTER_PASS_X64(ctx_manager, ov::intel_cpu::SDPAFuseTransposeReshape); - manager.run_passes(f); - return false; + return symbolic_optimizations.run_on_model(f); } } // namespace ov::intel_cpu