39 commits
0dcd42c
rewrite StatefulSDPAFusion transformation
evkotov Jul 18, 2025
5d1e51f
code review fix: check axis equality
evkotov Jul 21, 2025
b47854d
code review fix
evkotov Jul 25, 2025
f24a4d9
code review fixes
evkotov Jul 28, 2025
07a2fe0
code review fixes
evkotov Jul 29, 2025
0382e86
build fixes
evkotov Jul 30, 2025
2f7e852
code review fixes
evkotov Jul 30, 2025
1048efa
use const auto& ctx_manager
evkotov Jul 31, 2025
7f80076
clang fixes
evkotov Jul 31, 2025
3b805c2
fix for ov_tensorflow_frontend_tests CompileModelsTests.NgramCompilation
evkotov Aug 1, 2025
973660b
code review fix
evkotov Aug 4, 2025
e632e65
code review fixes
evkotov Aug 4, 2025
27cfca8
clang fixes
evkotov Aug 4, 2025
e333df1
clang fixes
evkotov Aug 4, 2025
8bf9d16
restore
evkotov Aug 4, 2025
159eae0
fix symbolic optimizations
evkotov Aug 5, 2025
2555689
cleanup
evkotov Aug 5, 2025
b3ca0d9
cleanup
evkotov Aug 5, 2025
af39142
cleanup
evkotov Aug 5, 2025
1bd46af
clang
evkotov Aug 6, 2025
5518300
move transformations to main pipeline
evkotov Aug 6, 2025
0a111d0
clang
evkotov Aug 6, 2025
f826f34
move back transformations
evkotov Aug 6, 2025
5366bc9
clang fixes
evkotov Aug 6, 2025
7c4e8b2
use Manager to get private PassConfig
evkotov Aug 6, 2025
96890c0
fix SymbolicOptimization
evkotov Aug 6, 2025
0466ad8
clang fixes
evkotov Aug 6, 2025
941de2b
add comment due to code review
evkotov Aug 7, 2025
d8f3764
return TODO as comment review fix
evkotov Aug 7, 2025
3e78140
code review fix
evkotov Aug 7, 2025
6ae3a1f
clang fix
evkotov Aug 7, 2025
6ffdc1a
fix ov_cpu_func_tests smoke_RoPETestChatGLMSlice/RoPETestChatGLMSlice…
evkotov Aug 8, 2025
4d04a25
pytorch layer test fix
evkotov Aug 9, 2025
f3a5c41
Fix PassConfig State Management in SymbolicOptimizations
evkotov Aug 12, 2025
0b2dfb2
Merge remote-tracking branch 'origin/master' into CVS-170030
evkotov Aug 12, 2025
14ace40
use PassConfig copying constructor
evkotov Aug 12, 2025
c645293
clang fix
evkotov Aug 12, 2025
6846eb8
cleanup
evkotov Aug 12, 2025
f3af756
Merge branch 'master' into CVS-170030
evkotov Aug 12, 2025
@@ -210,10 +210,23 @@ bool ov::pass::SymbolicOptimizations::run_on_model(const std::shared_ptr<ov::Model>& m) {
     // it may break NNCF patterns and lead to unexpected FakeQuantize ops in the model.
     // So we decided to disable these passes in SymbolicOptimizations.
     const auto& pass_config = m_manager->get_pass_config();
 
+    const auto old_pass_config = *pass_config;
Review thread on the line above:

Contributor:
What about making a copy of the config at L:212:

    const auto pass_config = m_manager->get_pass_config();

then disabling some passes on this copy? In case of error, restoring the original pass config should not be required.

Contributor (Author):
Unfortunately, there is no method Manager::set_pass_config. Should we add this method?

Contributor:
The set method should not be required; the idea was to refactor like:

    // make a copy of config
    const auto pass_config = m_manager->get_pass_config();

    // disable passes on copy only
    pass_config->disable<EliminateSqueeze>();
    pass_config->disable<EliminateUnsqueeze>();

    auto result = m_manager->run_passes(m);

    ov::remove_skip_invalidation_rti(m);
    return result;

but maybe something is missed and it will not work correctly.

+    // Temporarily disable passes for SymbolicOptimizations execution
     pass_config->disable<EliminateSqueeze>();
     pass_config->disable<EliminateUnsqueeze>();
 
-    m_manager->run_passes(m);
+    bool result;
+    try {
+        result = m_manager->run_passes(m);
+    } catch (...) {
+        *pass_config = old_pass_config;  // Restore original pass config on exception
+        throw;                           // Re-throw the exception
+    }
+
+    *pass_config = old_pass_config;
 
     ov::remove_skip_invalidation_rti(m);
-    return true;
+    return result;
 }
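
Editor's note: Manager::get_pass_config() returns a std::shared_ptr<PassConfig>, so the copy proposed in the thread above would only copy the pointer; the disable<>() calls would still mutate the manager's live configuration. That is why the merged code copies the PassConfig by value and restores it on both the normal and the exception path. Below is a minimal sketch of an RAII variant of the same idea (hypothetical, not part of the PR), assuming it runs in the scope of run_on_model above and that PassConfig is copy-assignable, which the *pass_config = old_pass_config line in the final diff implies:

    // Hypothetical RAII guard: restores the saved PassConfig on every exit path,
    // replacing the duplicated restore-and-rethrow of the merged code.
    struct PassConfigRestore {
        ov::pass::PassConfig& live;   // the manager's shared, live config
        ov::pass::PassConfig saved;   // value copy taken at construction
        explicit PassConfigRestore(ov::pass::PassConfig& cfg) : live(cfg), saved(cfg) {}
        ~PassConfigRestore() { live = saved; }  // runs on return and on throw
    };

    const auto& pass_config = m_manager->get_pass_config();  // std::shared_ptr<PassConfig>
    PassConfigRestore guard(*pass_config);
    pass_config->disable<EliminateSqueeze>();
    pass_config->disable<EliminateUnsqueeze>();
    const bool result = m_manager->run_passes(m);  // guard restores even if this throws
    ov::remove_skip_invalidation_rti(m);
    return result;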
@@ -10,10 +10,6 @@
 #include <cstddef>
 #include <cstdint>
 #include <memory>
-#include <openvino/core/rt_info.hpp>
-#include <openvino/pass/manager.hpp>
-#include <openvino/pass/pattern/op/wrap_type.hpp>
-#include <transformations/utils/gen_pattern.hpp>
 #include <tuple>
 #include <vector>

@@ -23,7 +19,9 @@
 #include "openvino/core/node.hpp"
 #include "openvino/core/node_output.hpp"
 #include "openvino/core/node_vector.hpp"
+#include "openvino/core/rt_info.hpp"
 #include "openvino/core/type.hpp"
+#include "openvino/core/type/element_type.hpp"
 #include "openvino/op/assign.hpp"
 #include "openvino/op/concat.hpp"
 #include "openvino/op/constant.hpp"
@@ -36,17 +34,19 @@
 #include "openvino/pass/matcher_pass.hpp"
 #include "openvino/pass/pattern/matcher.hpp"
 #include "openvino/pass/pattern/op/label.hpp"
+#include "openvino/pass/pattern/op/pattern.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "transformations/common_optimizations/simplify_shape_of_sub_graph.hpp"
 #include "transformations/cpu_opset/common/op/sdpa.hpp"
 #include "transformations/defs.hpp"
+#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"
 #include "transformations/transpose_sinking/ts_shape_of.hpp"
 #include "transformations/utils/utils.hpp"
 
 #if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
 #    include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp"
 #endif
 
-using namespace ov::gen_pattern;
 using namespace ov::pass;
 
 namespace ov::intel_cpu {
@@ -55,28 +55,24 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
     MATCHER_SCOPE(StatefulSDPAFusion);
     using namespace ov::pass::pattern;
 
-    auto beam_idx = makePattern("i32[?]");
+    auto beam_idx = any_input(type_matches(element::i32) && shape_matches("[?]"));
     auto cur_q = any_input();
     auto cur_k = any_input();
     auto cur_v = any_input();
 
-    auto axis_seq_len = ov::gen_pattern::Symbol("axis_seq_len");
-    auto axis_beam = ov::gen_pattern::Symbol("axis_beam");
+    auto past_k = wrap_type<ov::op::v6::ReadValue>();
+    auto past_v = wrap_type<ov::op::v6::ReadValue>();
 
-    // past_kv can be BHLS/LBHS
-    auto past_k = makePattern<ov::op::v6::ReadValue>({});
-    auto past_v = makePattern<ov::op::v6::ReadValue>({});
-
-    auto convert_past_k = makePattern<ov::op::v0::Convert>({past_k});
-    auto convert_past_v = makePattern<ov::op::v0::Convert>({past_v});
+    auto convert_past_k = wrap_type<ov::op::v0::Convert>({past_k});
+    auto convert_past_v = wrap_type<ov::op::v0::Convert>({past_v});
 
     auto gather_input_k =
-        makePattern<ov::op::v8::Gather>({past_k | convert_past_k, beam_idx, axis_beam}, {{"batch_dims", 0}});
+        wrap_type<ov::op::v8::Gather>({past_k | convert_past_k, beam_idx, "axis_beam"}, {{"batch_dims", 0}});
     auto gather_input_v =
-        makePattern<ov::op::v8::Gather>({past_v | convert_past_v, beam_idx, axis_beam}, {{"batch_dims", 0}});
+        wrap_type<ov::op::v8::Gather>({past_v | convert_past_v, beam_idx, "axis_beam"}, {{"batch_dims", 0}});
 
-    auto concat_k = makePattern<ov::op::v0::Concat>({gather_input_k, cur_k}, {{"axis", axis_seq_len}});
-    auto concat_v = makePattern<ov::op::v0::Concat>({gather_input_v, cur_v}, {{"axis", axis_seq_len}});
+    auto concat_k = wrap_type<ov::op::v0::Concat>({gather_input_k, cur_k});
+    auto concat_v = wrap_type<ov::op::v0::Concat>({gather_input_v, cur_v});
 
     std::shared_ptr<Node> mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k;
     std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k) =
@@ -88,32 +84,35 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
     auto present_v = concat_v | mq_reshape_v;
 
     // canonical q/k/v shape definition: [B,H,...L,S]
-    auto sdp0 = makePattern<ov::op::v13::ScaledDotProductAttention>({cur_q, present_k, present_v});
-    auto sdp1 = makePattern<ov::op::v13::ScaledDotProductAttention>({cur_q, present_k, present_v, any_input()});
+    auto sdp0 = wrap_type<ov::op::v13::ScaledDotProductAttention>({cur_q, present_k, present_v});
+    auto sdp1 = wrap_type<ov::op::v13::ScaledDotProductAttention>({cur_q, present_k, present_v, any_input()});
     auto sdp2 =
-        makePattern<ov::op::v13::ScaledDotProductAttention>({cur_q, present_k, present_v, any_input(), any_input()});
+        wrap_type<ov::op::v13::ScaledDotProductAttention>({cur_q, present_k, present_v, any_input(), any_input()});
 
     // non-canonical q/k/v shape definitions, for example: [L, B, H, S]/[B, L, H, S]
     auto order_k = wrap_type<ov::op::v0::Constant>();
     auto order_v = wrap_type<ov::op::v0::Constant>();
     auto order_q = wrap_type<ov::op::v0::Constant>();
-    auto transpose_q = makePattern<ov::op::v1::Transpose>({cur_q, order_q});
-    auto transpose_k = makePattern<ov::op::v1::Transpose>({present_k, order_k});
-    auto transpose_v = makePattern<ov::op::v1::Transpose>({present_v, order_v});
+    auto transpose_q = wrap_type<ov::op::v1::Transpose>({cur_q, order_q});
+    auto transpose_k = wrap_type<ov::op::v1::Transpose>({present_k, order_k});
+    auto transpose_v = wrap_type<ov::op::v1::Transpose>({present_v, order_v});
 
-    auto sdp_trans0 = makePattern<ov::op::v13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v});
+    auto sdp_trans0 = wrap_type<ov::op::v13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v});
     auto sdp_trans1 =
-        makePattern<ov::op::v13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input()});
-    auto sdp_trans2 = makePattern<ov::op::v13::ScaledDotProductAttention>(
+        wrap_type<ov::op::v13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input()});
+    auto sdp_trans2 = wrap_type<ov::op::v13::ScaledDotProductAttention>(
         {transpose_q, transpose_k, transpose_v, any_input(), any_input()});
 
     auto sdp = sdp0 | sdp1 | sdp2 | sdp_trans0 | sdp_trans1 | sdp_trans2;
 
     ov::matcher_pass_callback callback = [=](Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
         auto root = m.get_match_root();
-        PatternValidator validator(m);
-        if (!validator) {
+
+        // Check concat axes equality first
+        const auto concat_k_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
+        const auto concat_v_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
+        if (concat_k_node->get_axis() != concat_v_node->get_axis()) {
             return false;
         }

@@ -155,9 +154,6 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
         if (!check_valid_children_type(past_k_node) || !check_valid_children_type(past_v_node)) {
             return false;
         }
-        const auto concat_k_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
-        const auto concat_v_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
-
         for (auto&& item : {concat_k_node, concat_v_node}) {
             auto&& children = item->get_output_target_inputs(0);
             switch (children.size()) {
@@ -301,16 +297,16 @@
 
 bool SDPASubgraphFusion::run_on_model(const std::shared_ptr<ov::Model>& f) {
     RUN_ON_FUNCTION_SCOPE(SDPASubgraphFusion);
-    ov::pass::Manager manager("SDPASubgraphFusion");
+    ov::pass::SymbolicOptimizations symbolic_optimizations(false, get_pass_config());
+    auto& ctx_manager = *symbolic_optimizations.get_manager();
 
-    CPU_REGISTER_PASS_COMMON(manager, ov::pass::SimplifyGatherShapeOf);
-    CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward);
-    CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);
+    CPU_REGISTER_PASS_COMMON(ctx_manager, ov::pass::SimplifyGatherShapeOf);
+    CPU_REGISTER_PASS_COMMON(ctx_manager, ov::pass::transpose_sinking::TSShapeOfForward);
+    CPU_REGISTER_PASS_COMMON(ctx_manager, StatefulSDPAFusion);
     // TODO: remove the following after snippets support patterns with dynamic shapes
-    CPU_REGISTER_PASS_X64(manager, ov::intel_cpu::SDPAFuseTransposeReshape);
+    CPU_REGISTER_PASS_X64(ctx_manager, ov::intel_cpu::SDPAFuseTransposeReshape);
 
-    manager.run_passes(f);
-    return false;
+    return symbolic_optimizations.run_on_model(f);
 }
 
 }  // namespace ov::intel_cpu
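
Editor's note: the second file's changes migrate the matcher from the internal gen_pattern helpers (makePattern, gen_pattern::Symbol, PatternValidator) to the public ov::pass::pattern API (wrap_type, any_input, and predicates such as type_matches and shape_matches). One consequence visible in the diff is that wrap_type does not unify a Symbol such as "axis_seq_len" across nodes, so the concat-axis equality check moved into the matcher callback. A minimal self-contained sketch of that style follows; it is a hypothetical pass, not code from the PR:

    // Hypothetical MatcherPass illustrating the wrap_type/any_input style and an
    // explicit attribute-equality check in the callback.
    #include "openvino/core/type.hpp"
    #include "openvino/op/concat.hpp"
    #include "openvino/pass/matcher_pass.hpp"
    #include "openvino/pass/pattern/matcher.hpp"
    #include "openvino/pass/pattern/op/wrap_type.hpp"

    class ExampleConcatAxisCheck : public ov::pass::MatcherPass {
    public:
        ExampleConcatAxisCheck() {
            using namespace ov::pass::pattern;
            // Two chained Concats; the pattern itself carries no axis constraint.
            auto concat_a = wrap_type<ov::op::v0::Concat>({any_input(), any_input()});
            auto concat_b = wrap_type<ov::op::v0::Concat>({concat_a, any_input()});
            ov::matcher_pass_callback callback = [=](Matcher& m) {
                const auto& pm = m.get_pattern_value_map();
                const auto a = ov::as_type_ptr<ov::op::v0::Concat>(pm.at(concat_a).get_node_shared_ptr());
                const auto b = ov::as_type_ptr<ov::op::v0::Concat>(pm.at(concat_b).get_node_shared_ptr());
                // No Symbol unification with wrap_type: check the attribute explicitly,
                // just as StatefulSDPAFusion now compares the concat_k/concat_v axes.
                if (a->get_axis() != b->get_axis()) {
                    return false;  // reject the match
                }
                // A real pass would rewrite the graph here.
                return false;  // no modification in this sketch
            };
            register_matcher(std::make_shared<Matcher>(concat_b, "ExampleConcatAxisCheck"), callback);
        }
    };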